From 0416d6323a4064a4c616211b3d981f7cf78fef40 Mon Sep 17 00:00:00 2001
From: ninghongbin <2409766686@qq.com>
Date: Tue, 2 Sep 2025 19:49:54 +0800
Subject: [PATCH] ocr1.0

---
Review note: the two .py files were revised after this patch was generated,
so the blob ids on their "index" lines are stale (new-file hunks still apply).

 ocr/forbidden_words.txt       |  17 ++++
 ocr/logger_config.py          |  53 +++++++++++
 ocr/ocr_violation_detector.py | 170 ++++++++++++++++++++++++++++++++++
 3 files changed, 240 insertions(+)
 create mode 100644 ocr/forbidden_words.txt
 create mode 100644 ocr/logger_config.py
 create mode 100644 ocr/ocr_violation_detector.py

diff --git a/ocr/forbidden_words.txt b/ocr/forbidden_words.txt
new file mode 100644
index 0000000..7636b04
--- /dev/null
+++ b/ocr/forbidden_words.txt
@@ -0,0 +1,17 @@
+打倒习近平
+打到习近平
+打倒毛泽东
+打到毛泽东
+打到主席
+打倒主席
+打到共产主义
+打倒共产主义
+打到共产党
+打倒共产党
+胖猫
+法轮功
+法轮大法
+法轮大法好
+法轮功大法好
+法轮
+李洪志
\ No newline at end of file
diff --git a/ocr/logger_config.py b/ocr/logger_config.py
new file mode 100644
index 0000000..038d657
--- /dev/null
+++ b/ocr/logger_config.py
@@ -0,0 +1,53 @@
+# Logging configuration shared by the OCR violation-detector modules.
+import logging
+import sys
+
+
+def setup_logger():
+    """
+    Configure the shared "ViolationDetectorLogger" logger.
+
+    Attaches a console handler (INFO and above, written to stdout) and a
+    file handler (DEBUG and above, appended to violation_detector.log as
+    UTF-8). Repeated calls return the already-configured logger without
+    attaching duplicate handlers.
+
+    Returns:
+        logging.Logger: The configured logger.
+    """
+    logger = logging.getLogger("ViolationDetectorLogger")
+    logger.setLevel(logging.DEBUG)  # Handlers apply their own thresholds.
+
+    # Inspect only this logger's own handlers. Logger.hasHandlers() also
+    # counts handlers on ancestor loggers, so an application that had
+    # configured the root logger would make setup return early and the
+    # console/file handlers below would never be attached.
+    if logger.handlers:
+        return logger
+
+    # --- Console handler: INFO and above ---
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(logging.INFO)
+    console_formatter = logging.Formatter(
+        '%(asctime)s - %(levelname)s - [%(module)s:%(lineno)d] - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+    console_handler.setFormatter(console_formatter)
+
+    # --- File handler: DEBUG and above ---
+    file_handler = logging.FileHandler("violation_detector.log", mode='a', encoding='utf-8')
+    file_handler.setLevel(logging.DEBUG)
+    file_formatter = logging.Formatter(
+        '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+    file_handler.setFormatter(file_formatter)
+
+    logger.addHandler(console_handler)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+# Create and export the shared logger instance.
+logger = setup_logger()
diff --git a/ocr/ocr_violation_detector.py b/ocr/ocr_violation_detector.py
new file mode 100644
index 0000000..f2e31d2
--- /dev/null
+++ b/ocr/ocr_violation_detector.py
@@ -0,0 +1,170 @@
+import os
+import cv2
+from rapidocr import RapidOCR
+from logger_config import logger
+
+
+class OCRViolationDetector:
+    """
+    Wrap a RapidOCR engine to detect forbidden words in image frames.
+    """
+
+    # Resolved relative to the working directory. Built with os.path.join
+    # so the separator is correct on every platform; a literal backslash
+    # path would only resolve on Windows.
+    DEFAULT_OCR_CONFIG_PATH = os.path.join(".", "config", "1.yaml")
+
+    def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5,
+                 ocr_config_path: str = DEFAULT_OCR_CONFIG_PATH):
+        """
+        Load the forbidden-word list and initialize the OCR engine.
+
+        Args:
+            forbidden_words_path (str): Path to the forbidden-word .txt file.
+            ocr_confidence_threshold (float): Minimum OCR confidence a text
+                needs in order to be checked. The 0.5 default is kept low
+                on purpose to raise the detection rate.
+            ocr_config_path (str): Path to the RapidOCR YAML config file.
+        """
+        self.ocr_config_path = ocr_config_path
+        self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
+        self.ocr_engine = self._initialize_ocr()
+        self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold
+
+    def _load_forbidden_words(self, path):
+        """
+        Load the forbidden-word set from a text file.
+
+        Every non-blank line, stripped of surrounding whitespace, is one
+        entry; lines are not filtered for comment markers.
+
+        Args:
+            path (str): Path to the forbidden-word file.
+
+        Returns:
+            set: The loaded words, or an empty set when the file is missing
+            or unreadable (detection is then skipped).
+        """
+        words = set()
+        try:
+            # utf-8-sig reads plain UTF-8 too, but drops a leading BOM that
+            # str.strip() would otherwise leave glued to the first word.
+            with open(path, 'r', encoding='utf-8-sig') as f:
+                words = {line.strip() for line in f if line.strip()}
+            logger.info(f"成功加载 {len(words)} 个违禁词。")
+        except FileNotFoundError:
+            logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
+        except Exception as e:
+            logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测")
+        return words
+
+    def _initialize_ocr(self):
+        """
+        Create the RapidOCR engine from self.ocr_config_path.
+
+        Returns:
+            The RapidOCR engine, or None when construction fails; the error
+            is logged and detect() then reports no violations.
+        """
+        logger.info("正在初始化RapidOCR引擎...")
+        try:
+            engine = RapidOCR(config_path=self.ocr_config_path)
+            logger.info("RapidOCR引擎初始化成功。")
+            return engine
+        except Exception as e:
+            logger.error(f"RapidOCR引擎初始化失败: {e}")
+            return None
+
+    def detect(self, frame):
+        """
+        Run OCR on one frame and collect every forbidden word it contains.
+
+        Args:
+            frame: Image handed unchanged to the RapidOCR engine.
+
+        Returns:
+            tuple: (has_violation, words, confidences). words holds one
+            entry per (recognized text, forbidden word) match, and
+            confidences holds the OCR score of that text at the same
+            index. (False, [], []) is returned when the engine or the
+            word list is unavailable.
+        """
+        if self.ocr_engine is None or not self.forbidden_words:
+            return False, [], []
+
+        all_prohibited = []   # Forbidden words found in this frame.
+        all_confidences = []  # OCR confidence for each entry above.
+        # Sorted once so the result order is stable between runs; iterating
+        # the set directly depends on string hash randomization.
+        candidate_words = sorted(self.forbidden_words)
+
+        try:
+            result = self.ocr_engine(frame)
+            logger.debug(f"RapidOCR 原始返回结果: {result}")
+
+            # A missing or None txts/scores means no text was recognized.
+            # NOTE(review): RapidOCR appears to leave both as None on an
+            # empty result; zip(None, None) would raise TypeError and be
+            # logged as an error for every blank frame - confirm in its docs.
+            texts = getattr(result, 'txts', None)
+            confidences = getattr(result, 'scores', None)
+            if texts is None or confidences is None:
+                return False, [], []
+
+            for text, conf in zip(texts, confidences):
+                if conf < self.OCR_CONFIDENCE_THRESHOLD:
+                    logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过")
+                    continue
+
+                # One recognized text may contain several forbidden words.
+                for word in candidate_words:
+                    if word in text:
+                        logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}")
+                        all_prohibited.append(word)
+                        all_confidences.append(conf)
+
+        except Exception as e:
+            logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True)
+
+        # Matches gathered before any error are still returned.
+        return len(all_prohibited) > 0, all_prohibited, all_confidences
+
+
+# Manual smoke test, kept disabled. Uncomment to run against a local image.
+#
+# def test_single_image():
+#     """Run OCR violation detection on one image and print every match."""
+#     TEST_IMAGE_PATH = r"ocr/images/img_7.png"
+#     FORBIDDEN_WORDS_PATH = r"ocr/forbidden_words.txt"
+#     CONFIDENCE_THRESHOLD = 0.5
+#
+#     detector = OCRViolationDetector(
+#         forbidden_words_path=FORBIDDEN_WORDS_PATH,
+#         ocr_confidence_threshold=CONFIDENCE_THRESHOLD
+#     )
+#
+#     if not os.path.exists(TEST_IMAGE_PATH):
+#         print(f"错误:图片文件不存在 - {TEST_IMAGE_PATH}")
+#         return
+#
+#     frame = cv2.imread(TEST_IMAGE_PATH)
+#     if frame is None:
+#         print(f"错误:无法读取图片 - {TEST_IMAGE_PATH}")
+#         return
+#
+#     # Run the detection.
+#     has_violation, words, confidences = detector.detect(frame)
+#
+#     # Print every forbidden word that was detected.
+#     if has_violation:
+#         print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:")
+#         for word, conf in zip(words, confidences):
+#             print(f"- {word}(置信度:{conf:.4f})")
+#     else:
+#         print("测试结果:图片中未检测到违禁词")
+#
+#
+# if __name__ == "__main__":
+#     print("开始单张图片OCR违规检测测试...")
+#     test_single_image()
+#     print("测试完成")