From e21432c6a12cee47906616b4394541066d98ff17 Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Tue, 2 Sep 2025 21:30:28 +0800 Subject: [PATCH] ocr1.0 --- ocr/config/1.yaml | 119 +++++++++++++++++++++++++++++ ocr/forbidden_words.txt | 3 +- ocr/ocr_violation_detector.py | 138 ++++++++++++++++++++-------------- rtc/rtc.py | 21 +++++- 4 files changed, 219 insertions(+), 62 deletions(-) create mode 100644 ocr/config/1.yaml diff --git a/ocr/config/1.yaml b/ocr/config/1.yaml new file mode 100644 index 0000000..36ce9db --- /dev/null +++ b/ocr/config/1.yaml @@ -0,0 +1,119 @@ +Global: + text_score: 0.5 + + use_det: true + use_cls: true + use_rec: true + + min_height: 30 + width_height_ratio: 8 + max_side_len: 2000 + min_side_len: 30 + + return_word_box: false + return_single_char_box: false + + font_path: null + +EngineConfig: + onnxruntime: + intra_op_num_threads: -1 + inter_op_num_threads: -1 + enable_cpu_mem_arena: false + + cpu_ep_cfg: + arena_extend_strategy: "kSameAsRequested" + + use_cuda: true # 改为true以启用CUDA + cuda_ep_cfg: + device_id: 0 + arena_extend_strategy: "kNextPowerOfTwo" + cudnn_conv_algo_search: "EXHAUSTIVE" + do_copy_in_default_stream: true + + use_dml: false + dm_ep_cfg: null + + use_cann: false + cann_ep_cfg: + device_id: 0 + arena_extend_strategy: "kNextPowerOfTwo" + npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024 + op_select_impl_mode: "high_performance" + optypelist_for_implmode: "Gelu" + enable_cann_graph: true + + openvino: + inference_num_threads: -1 + performance_hint: null + performance_num_requests: -1 + enable_cpu_pinning: null + num_streams: -1 + enable_hyper_threading: null + scheduling_core_type: null + + paddle: + cpu_math_library_num_threads: -1 + use_npu: false + npu_id: 0 + use_cuda: true # 改为true以启用CUDA + gpu_id: 0 + gpu_mem: 500 + + torch: + use_cuda: true # 已经是true + gpu_id: 0 + +Det: + engine_type: "torch" + lang_type: "ch" + model_type: "mobile" + ocr_version: "PP-OCRv4" + + task_type: "det" + + model_path: null + model_dir: null + + limit_side_len: 736 + limit_type: min + std: [ 0.5, 0.5, 0.5 ] + mean: [ 0.5, 0.5, 0.5 ] + + thresh: 0.3 + box_thresh: 0.5 + max_candidates: 1000 + unclip_ratio: 1.6 + use_dilation: true + score_mode: fast + +Cls: + engine_type: "torch" + lang_type: "ch" + model_type: "mobile" + ocr_version: "PP-OCRv4" + + task_type: "cls" + + model_path: null + model_dir: null + + cls_image_shape: [3, 48, 192] + cls_batch_num: 6 + cls_thresh: 0.9 + label_list: ["0", "180"] + +Rec: + engine_type: "torch" + lang_type: "ch" + model_type: "mobile" + ocr_version: "PP-OCRv4" + + task_type: "rec" + + model_path: null + model_dir: null + + rec_keys_path: null + rec_img_shape: [3, 48, 320] + rec_batch_num: 6 diff --git a/ocr/forbidden_words.txt b/ocr/forbidden_words.txt index 7636b04..1b96a31 100644 --- a/ocr/forbidden_words.txt +++ b/ocr/forbidden_words.txt @@ -14,4 +14,5 @@ 法轮大法好 法轮功大法好 法轮 -李洪志 \ No newline at end of file +李洪志 +习近平 \ No newline at end of file diff --git a/ocr/ocr_violation_detector.py b/ocr/ocr_violation_detector.py index f2e31d2..8bf5d7e 100644 --- a/ocr/ocr_violation_detector.py +++ b/ocr/ocr_violation_detector.py @@ -1,7 +1,7 @@ import os import cv2 +import logging from rapidocr import RapidOCR -from logger_config import logger class OCRViolationDetector: @@ -9,47 +9,110 @@ class OCRViolationDetector: 封装RapidOCR引擎,用于检测图像帧中的违禁词。 """ - def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5): # 降低阈值提高检出率 + def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5, + log_level: int = logging.INFO, log_file: str = None): """ - 初始化OCR引擎和违禁词列表。 + 初始化OCR引擎、违禁词列表和日志配置。 Args: forbidden_words_path (str): 违禁词列表 .txt 文件的路径。 ocr_confidence_threshold (float): OCR识别结果的置信度阈值。 + log_level (int): 日志级别,默认为logging.INFO + log_file (str, optional): 日志文件路径,如不提供则只输出到控制台 """ + # 初始化日志 + self.logger = self._setup_logger(log_level, log_file) + + # 加载违禁词 self.forbidden_words = self._load_forbidden_words(forbidden_words_path) + + # 初始化OCR引擎 self.ocr_engine = self._initialize_ocr() + + # 设置置信度阈值 self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold + self.logger.info(f"OCR置信度阈值设置为: {ocr_confidence_threshold}") + + def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger: + """ + 配置日志系统 + + Args: + log_level: 日志级别 + log_file: 日志文件路径,如为None则只输出到控制台 + + Returns: + 配置好的logger实例 + """ + # 创建logger + logger = logging.getLogger('OCRViolationDetector') + logger.setLevel(log_level) + + # 避免重复添加处理器 + if logger.handlers: + return logger + + # 定义日志格式 + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + # 添加控制台处理器 + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # 如果提供了日志文件路径,则添加文件处理器 + if log_file: + try: + # 确保日志目录存在 + log_dir = os.path.dirname(log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir) + + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.info(f"日志文件将保存至: {log_file}") + except Exception as e: + logger.warning(f"无法创建日志文件处理器: {str(e)},仅输出至控制台") + + return logger def _load_forbidden_words(self, path): - """从txt文件加载违禁词列表(与rapidocr_test.py保持一致)""" + """从txt文件加载违禁词列表""" words = set() if not os.path.exists(path): - logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测") + self.logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测") return words try: with open(path, 'r', encoding='utf-8') as f: - # 去除每行首尾空格和换行符,过滤空行(不排除注释行,与测试代码统一) + # 去除每行首尾空格和换行符,过滤空行 words = {line.strip() for line in f if line.strip()} - logger.info(f"成功加载 {len(words)} 个违禁词。") + self.logger.info(f"成功加载 {len(words)} 个违禁词。") except Exception as e: - logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测") + self.logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测") return words def _initialize_ocr(self): """初始化RapidOCR引擎""" - logger.info("正在初始化RapidOCR引擎...") + self.logger.info("正在初始化RapidOCR引擎...") - config_path = r".\config\1.yaml" + config_path = r"../ocr/config/1.yaml" try: + # 检查配置文件是否存在 + if not os.path.exists(config_path): + self.logger.error(f"RapidOCR配置文件不存在: {config_path}") + return None + engine = RapidOCR( config_path=config_path ) - logger.info("RapidOCR引擎初始化成功。") + self.logger.info("RapidOCR引擎初始化成功。") return engine except Exception as e: - logger.error(f"RapidOCR引擎初始化失败: {e}") + self.logger.error(f"RapidOCR引擎初始化失败: {e}") return None def detect(self, frame): @@ -57,6 +120,7 @@ class OCRViolationDetector: 对单帧图像进行OCR,检测所有出现的违禁词并返回列表 返回格式:(是否有违禁词, 违禁词列表, 对应的置信度列表) """ + print("收到帧") if not self.ocr_engine or not self.forbidden_words: return False, [], [] @@ -64,9 +128,9 @@ class OCRViolationDetector: all_confidences = [] # 存储对应违禁词的置信度 try: - # 执行OCR识别(修正调用方式,与测试代码一致) + # 执行OCR识别 result = self.ocr_engine(frame) - logger.debug(f"RapidOCR 原始返回结果: {result}") + self.logger.debug(f"RapidOCR 原始返回结果: {result}") if result is None: return False, [], [] @@ -78,59 +142,19 @@ class OCRViolationDetector: # 遍历所有识别结果,收集所有违禁词 for text, conf in zip(texts, confidences): if conf < self.OCR_CONFIDENCE_THRESHOLD: - logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过") + self.logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过") continue # 检查当前文本中是否包含多个违禁词 for word in self.forbidden_words: if word in text: - logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}") + self.logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}") all_prohibited.append(word) all_confidences.append(conf) except Exception as e: - logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True) - + self.logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True) # 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) return len(all_prohibited) > 0, all_prohibited, all_confidences - - - -# def test_single_image(): -# """测试单张图片的OCR违规检测(显示所有违禁词)""" -# TEST_IMAGE_PATH = r"ocr/images/img_7.png" # 修正路径格式 -# FORBIDDEN_WORDS_PATH = r"ocr/forbidden_words.txt" -# CONFIDENCE_THRESHOLD = 0.5 -# -# detector = OCRViolationDetector( -# forbidden_words_path=FORBIDDEN_WORDS_PATH, -# ocr_confidence_threshold=CONFIDENCE_THRESHOLD -# ) -# -# if not os.path.exists(TEST_IMAGE_PATH): -# print(f"错误:图片文件不存在 - {TEST_IMAGE_PATH}") -# return -# -# frame = cv2.imread(TEST_IMAGE_PATH) -# if frame is None: -# print(f"错误:无法读取图片 - {TEST_IMAGE_PATH}") -# return -# -# # 执行检测 -# has_violation, words, confidences = detector.detect(frame) -# -# # 输出所有检测到的违禁词 -# if has_violation: -# print(f"测试结果:图片中共检测到 {len(words)} 个违禁词:") -# for word, conf in zip(words, confidences): -# print(f"- {word}(置信度:{conf:.4f})") -# else: -# print("测试结果:图片中未检测到违禁词") -# -# -# if __name__ == "__main__": -# print("开始单张图片OCR违规检测测试...") -# test_single_image() -# print("测试完成") \ No newline at end of file diff --git a/rtc/rtc.py b/rtc/rtc.py index aa4a78d..a160765 100644 --- a/rtc/rtc.py +++ b/rtc/rtc.py @@ -4,6 +4,7 @@ import cv2 # 导入OpenCV库 import numpy as np from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration from aiortc.mediastreams import MediaStreamTrack +from ocr.ocr_violation_detector import OCRViolationDetector class VideoTrack(MediaStreamTrack): @@ -47,7 +48,7 @@ async def rtc_frame_receiver(url, frame_queue): if frame_queue.empty(): # 队列为空、放入当前cv2帧 await frame_queue.put(frame_cv2) - print(f"第{total_frames}帧:队列为空、已放入新的cv2帧,尺寸: {frame_cv2.shape}") + # print(f"第{total_frames}帧:队列为空、已放入新的cv2帧,尺寸: {frame_cv2.shape}") else: # 队列非空、说明上一帧还未处理、跳过当前帧 print(f"第{total_frames}帧:队列非空、跳过该帧") @@ -93,23 +94,35 @@ async def frame_consumer(frame_queue): Args: frame_queue: 帧队列 """ + # 创建OCR检测器实例(请替换为实际的违禁词文件路径) + ocr_detector = OCRViolationDetector( + forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt", # 替换为实际路径 + ocr_confidence_threshold=0.5,) + while True: # 从队列中获取cv2帧(队列为空时会阻塞等待新帧) current_frame = await frame_queue.get() + ocr_detector.detect(current_frame) + + + + + + # 验证这是cv2可以处理的帧 - print(f"从队列获取到cv2帧、尺寸: {current_frame.shape}、数据类型: {current_frame.dtype}") + # print(f"从队列获取到cv2帧、尺寸: {current_frame.shape}、数据类型: {current_frame.dtype}") # 这里可以添加cv2的处理代码,例如显示帧 # cv2.imshow('Received Frame', current_frame) # if cv2.waitKey(1) & 0xFF == ord('q'): # break - print("cv2帧处理完成") + # print("cv2帧处理完成") # 标记任务完成 frame_queue.task_done() - print("帧处理完成、队列已清空") + # print("帧处理完成、队列已清空") async def main():