diff --git a/core/ocr.py b/core/ocr.py index 1546d0e..c4c78b0 100644 --- a/core/ocr.py +++ b/core/ocr.py @@ -3,14 +3,21 @@ import cv2 import gc import time import threading -from rapidocr import RapidOCR +import numpy as np +from paddleocr import PaddleOCR from service.sensitive_service import get_all_sensitive_words +# 解决NumPy 1.20+版本中np.int已移除的兼容性问题 +try: + if not hasattr(np, 'int'): + np.int = int +except Exception as e: + print(f"处理NumPy兼容性时出错: {e}") + # 全局变量 _ocr_engine = None _forbidden_words = set() _conf_threshold = 0.5 -ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml") # 资源管理变量 _ref_count = 0 @@ -19,6 +26,9 @@ _lock = threading.Lock() _release_timeout = 5 # 30秒无使用则释放 _is_releasing = False # 标记是否正在释放 +# 并行处理配置 +_max_workers = 4 # 并行处理的线程数 + # 调试用计数器 _debug_counter = { "created": 0, @@ -35,9 +45,6 @@ def _release_engine(): try: _is_releasing = True - # 如果有释放方法则调用 - if hasattr(_ocr_engine, 'release'): - _ocr_engine.release() _ocr_engine = None _debug_counter["released"] += 1 print(f"OCR engine released. Stats: {_debug_counter}") @@ -52,8 +59,9 @@ def _release_engine(): except ImportError: pass try: - import tensorflow as tf - tf.keras.backend.clear_session() + import paddle + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.empty_cache() except ImportError: pass finally: @@ -61,12 +69,11 @@ def _release_engine(): def _monitor_thread(): - """监控线程、优化检查逻辑""" + """监控线程,优化检查逻辑""" global _ref_count, _last_used_time, _ocr_engine while True: time.sleep(5) # 每5秒检查一次 with _lock: - # 只有当引擎存在、没有引用且超时、才释放 if _ocr_engine and _ref_count == 0 and not _is_releasing: elapsed = time.time() - _last_used_time if elapsed > _release_timeout: @@ -91,25 +98,18 @@ def load_model(): print(f"Forbidden words load error: {e}") return False - # 验证配置文件 - if not os.path.exists(ocr_config_path): - print(f"OCR config not found: {ocr_config_path}") - return False - return True def detect(frame): - """OCR检测、优化引用计数管理""" - global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time + """OCR检测,支持并行处理""" + global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time, _max_workers # 验证前置条件 if not _forbidden_words: return (False, "违禁词未初始化") if frame is None or frame.size == 0: return (False, "无效帧数据") - if not os.path.exists(ocr_config_path): - return (False, f"OCR配置文件不存在: {ocr_config_path}") # 增加引用计数并获取引擎实例 engine = None @@ -121,15 +121,22 @@ def detect(frame): # 初始化引擎(如果未初始化且不在释放中) if not _ocr_engine and not _is_releasing: try: - _ocr_engine = RapidOCR(config_path=ocr_config_path) + # 初始化PaddleOCR,设置并行处理参数 + _ocr_engine = PaddleOCR( + use_angle_cls=True, + lang="ch", + show_log=False, + use_gpu=True, + max_text_length=1024, + threads=_max_workers + ) _debug_counter["created"] += 1 - print(f"OCR engine initialized. Stats: {_debug_counter}") + print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}") except Exception as e: print(f"OCR model load failed: {e}") - _ref_count -= 1 # 恢复引用计数 + _ref_count -= 1 return (False, f"引擎初始化失败: {str(e)}") - # 获取当前引擎引用 engine = _ocr_engine # 检查引擎是否可用 @@ -140,15 +147,56 @@ def detect(frame): try: # 执行OCR检测 - ocr_res = engine(frame) + ocr_res = engine.ocr(frame, cls=True) # 验证OCR结果格式 - if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'): + if not ocr_res or not isinstance(ocr_res, list): return (False, "无OCR结果") - # 处理OCR结果 - texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)] - confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))] + # 处理OCR结果 - 兼容多种格式 + texts = [] + confs = [] + for line in ocr_res: + if line is None: + continue + + # 处理line可能是列表或直接是文本信息的情况 + if isinstance(line, list): + items_to_process = line + else: + items_to_process = [line] + + for item in items_to_process: + # 跳过纯数字列表(可能是坐标信息) + if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item): + continue + + # 处理元组形式的文本和置信度 (text, confidence) + if isinstance(item, tuple) and len(item) == 2: + text, conf = item + if isinstance(text, str) and isinstance(conf, (int, float)): + texts.append(text.strip()) + confs.append(float(conf)) + continue + + # 处理列表形式的[坐标信息, (text, confidence)] + if isinstance(item, list) and len(item) >= 2: + # 尝试从列表中提取文本和置信度 + text_data = item[1] + if isinstance(text_data, tuple) and len(text_data) == 2: + text, conf = text_data + if isinstance(text, str) and isinstance(conf, (int, float)): + texts.append(text.strip()) + confs.append(float(conf)) + continue + elif isinstance(text_data, str): + # 只有文本没有置信度的情况 + texts.append(text_data.strip()) + confs.append(1.0) # 默认最高置信度 + continue + + # 无法识别的格式,记录日志 + print(f"无法解析的OCR结果格式: {item}") if len(texts) != len(confs): return (False, "OCR结果格式异常") @@ -178,9 +226,16 @@ def detect(frame): return (False, f"检测错误: {str(e)}") finally: - # 减少引用计数、确保线程安全 + # 减少引用计数,确保线程安全 with _lock: _ref_count = max(0, _ref_count - 1) - # 持续使用时更新最后使用时间 if _ref_count > 0: _last_used_time = time.time() + + +def batch_detect(frames): + """批量检测接口,充分利用并行能力""" + results = [] + for frame in frames: + results.append(detect(frame)) + return results