paddleocr
This commit is contained in:
		
							
								
								
									
										113
									
								
								core/ocr.py
									
									
									
									
									
								
							
							
						
						
									
										113
									
								
								core/ocr.py
									
									
									
									
									
								
							| @ -3,14 +3,21 @@ import cv2 | |||||||
| import gc | import gc | ||||||
| import time | import time | ||||||
| import threading | import threading | ||||||
| from rapidocr import RapidOCR | import numpy as np | ||||||
|  | from paddleocr import PaddleOCR | ||||||
| from service.sensitive_service import get_all_sensitive_words | from service.sensitive_service import get_all_sensitive_words | ||||||
|  |  | ||||||
|  | # 解决NumPy 1.20+版本中np.int已移除的兼容性问题 | ||||||
|  | try: | ||||||
|  |     if not hasattr(np, 'int'): | ||||||
|  |         np.int = int | ||||||
|  | except Exception as e: | ||||||
|  |     print(f"处理NumPy兼容性时出错: {e}") | ||||||
|  |  | ||||||
| # 全局变量 | # 全局变量 | ||||||
| _ocr_engine = None | _ocr_engine = None | ||||||
| _forbidden_words = set() | _forbidden_words = set() | ||||||
| _conf_threshold = 0.5 | _conf_threshold = 0.5 | ||||||
| ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml") |  | ||||||
|  |  | ||||||
| # 资源管理变量 | # 资源管理变量 | ||||||
| _ref_count = 0 | _ref_count = 0 | ||||||
| @ -19,6 +26,9 @@ _lock = threading.Lock() | |||||||
| _release_timeout = 5  # 5秒无使用则释放 | _release_timeout = 5  # 5秒无使用则释放 | ||||||
| _is_releasing = False  # 标记是否正在释放 | _is_releasing = False  # 标记是否正在释放 | ||||||
|  |  | ||||||
|  | # 并行处理配置 | ||||||
|  | _max_workers = 4  # 并行处理的线程数 | ||||||
|  |  | ||||||
| # 调试用计数器 | # 调试用计数器 | ||||||
| _debug_counter = { | _debug_counter = { | ||||||
|     "created": 0, |     "created": 0, | ||||||
| @ -35,9 +45,6 @@ def _release_engine(): | |||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         _is_releasing = True |         _is_releasing = True | ||||||
|         # 如果有释放方法则调用 |  | ||||||
|         if hasattr(_ocr_engine, 'release'): |  | ||||||
|             _ocr_engine.release() |  | ||||||
|         _ocr_engine = None |         _ocr_engine = None | ||||||
|         _debug_counter["released"] += 1 |         _debug_counter["released"] += 1 | ||||||
|         print(f"OCR engine released. Stats: {_debug_counter}") |         print(f"OCR engine released. Stats: {_debug_counter}") | ||||||
| @ -52,8 +59,9 @@ def _release_engine(): | |||||||
|         except ImportError: |         except ImportError: | ||||||
|             pass |             pass | ||||||
|         try: |         try: | ||||||
|             import tensorflow as tf |             import paddle | ||||||
|             tf.keras.backend.clear_session() |             if paddle.is_compiled_with_cuda(): | ||||||
|  |                 paddle.device.cuda.empty_cache() | ||||||
|         except ImportError: |         except ImportError: | ||||||
|             pass |             pass | ||||||
|     finally: |     finally: | ||||||
| @ -61,12 +69,11 @@ def _release_engine(): | |||||||
|  |  | ||||||
|  |  | ||||||
| def _monitor_thread(): | def _monitor_thread(): | ||||||
|     """监控线程、优化检查逻辑""" |     """监控线程,优化检查逻辑""" | ||||||
|     global _ref_count, _last_used_time, _ocr_engine |     global _ref_count, _last_used_time, _ocr_engine | ||||||
|     while True: |     while True: | ||||||
|         time.sleep(5)  # 每5秒检查一次 |         time.sleep(5)  # 每5秒检查一次 | ||||||
|         with _lock: |         with _lock: | ||||||
|             # 只有当引擎存在、没有引用且超时、才释放 |  | ||||||
|             if _ocr_engine and _ref_count == 0 and not _is_releasing: |             if _ocr_engine and _ref_count == 0 and not _is_releasing: | ||||||
|                 elapsed = time.time() - _last_used_time |                 elapsed = time.time() - _last_used_time | ||||||
|                 if elapsed > _release_timeout: |                 if elapsed > _release_timeout: | ||||||
| @ -91,25 +98,18 @@ def load_model(): | |||||||
|         print(f"Forbidden words load error: {e}") |         print(f"Forbidden words load error: {e}") | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     # 验证配置文件 |  | ||||||
|     if not os.path.exists(ocr_config_path): |  | ||||||
|         print(f"OCR config not found: {ocr_config_path}") |  | ||||||
|         return False |  | ||||||
|  |  | ||||||
|     return True |     return True | ||||||
|  |  | ||||||
|  |  | ||||||
| def detect(frame): | def detect(frame): | ||||||
|     """OCR检测、优化引用计数管理""" |     """OCR检测,支持并行处理""" | ||||||
|     global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time |     global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time, _max_workers | ||||||
|  |  | ||||||
|     # 验证前置条件 |     # 验证前置条件 | ||||||
|     if not _forbidden_words: |     if not _forbidden_words: | ||||||
|         return (False, "违禁词未初始化") |         return (False, "违禁词未初始化") | ||||||
|     if frame is None or frame.size == 0: |     if frame is None or frame.size == 0: | ||||||
|         return (False, "无效帧数据") |         return (False, "无效帧数据") | ||||||
|     if not os.path.exists(ocr_config_path): |  | ||||||
|         return (False, f"OCR配置文件不存在: {ocr_config_path}") |  | ||||||
|  |  | ||||||
|     # 增加引用计数并获取引擎实例 |     # 增加引用计数并获取引擎实例 | ||||||
|     engine = None |     engine = None | ||||||
| @ -121,15 +121,22 @@ def detect(frame): | |||||||
|         # 初始化引擎(如果未初始化且不在释放中) |         # 初始化引擎(如果未初始化且不在释放中) | ||||||
|         if not _ocr_engine and not _is_releasing: |         if not _ocr_engine and not _is_releasing: | ||||||
|             try: |             try: | ||||||
|                 _ocr_engine = RapidOCR(config_path=ocr_config_path) |                 # 初始化PaddleOCR,设置并行处理参数 | ||||||
|  |                 _ocr_engine = PaddleOCR( | ||||||
|  |                     use_angle_cls=True, | ||||||
|  |                     lang="ch", | ||||||
|  |                     show_log=False, | ||||||
|  |                     use_gpu=True, | ||||||
|  |                     max_text_length=1024, | ||||||
|  |                     threads=_max_workers | ||||||
|  |                 ) | ||||||
|                 _debug_counter["created"] += 1 |                 _debug_counter["created"] += 1 | ||||||
|                 print(f"OCR engine initialized. Stats: {_debug_counter}") |                 print(f"PaddleOCR engine initialized with {_max_workers} workers. Stats: {_debug_counter}") | ||||||
|             except Exception as e: |             except Exception as e: | ||||||
|                 print(f"OCR model load failed: {e}") |                 print(f"OCR model load failed: {e}") | ||||||
|                 _ref_count -= 1  # 恢复引用计数 |                 _ref_count -= 1 | ||||||
|                 return (False, f"引擎初始化失败: {str(e)}") |                 return (False, f"引擎初始化失败: {str(e)}") | ||||||
|  |  | ||||||
|         # 获取当前引擎引用 |  | ||||||
|         engine = _ocr_engine |         engine = _ocr_engine | ||||||
|  |  | ||||||
|     # 检查引擎是否可用 |     # 检查引擎是否可用 | ||||||
| @ -140,15 +147,56 @@ def detect(frame): | |||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         # 执行OCR检测 |         # 执行OCR检测 | ||||||
|         ocr_res = engine(frame) |         ocr_res = engine.ocr(frame, cls=True) | ||||||
|  |  | ||||||
|         # 验证OCR结果格式 |         # 验证OCR结果格式 | ||||||
|         if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'): |         if not ocr_res or not isinstance(ocr_res, list): | ||||||
|             return (False, "无OCR结果") |             return (False, "无OCR结果") | ||||||
|  |  | ||||||
|         # 处理OCR结果 |         # 处理OCR结果 - 兼容多种格式 | ||||||
|         texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)] |         texts = [] | ||||||
|         confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))] |         confs = [] | ||||||
|  |         for line in ocr_res: | ||||||
|  |             if line is None: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             # 处理line可能是列表或直接是文本信息的情况 | ||||||
|  |             if isinstance(line, list): | ||||||
|  |                 items_to_process = line | ||||||
|  |             else: | ||||||
|  |                 items_to_process = [line] | ||||||
|  |  | ||||||
|  |             for item in items_to_process: | ||||||
|  |                 # 跳过纯数字列表(可能是坐标信息) | ||||||
|  |                 if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item): | ||||||
|  |                     continue | ||||||
|  |  | ||||||
|  |                 # 处理元组形式的文本和置信度 (text, confidence) | ||||||
|  |                 if isinstance(item, tuple) and len(item) == 2: | ||||||
|  |                     text, conf = item | ||||||
|  |                     if isinstance(text, str) and isinstance(conf, (int, float)): | ||||||
|  |                         texts.append(text.strip()) | ||||||
|  |                         confs.append(float(conf)) | ||||||
|  |                         continue | ||||||
|  |  | ||||||
|  |                 # 处理列表形式的[坐标信息, (text, confidence)] | ||||||
|  |                 if isinstance(item, list) and len(item) >= 2: | ||||||
|  |                     # 尝试从列表中提取文本和置信度 | ||||||
|  |                     text_data = item[1] | ||||||
|  |                     if isinstance(text_data, tuple) and len(text_data) == 2: | ||||||
|  |                         text, conf = text_data | ||||||
|  |                         if isinstance(text, str) and isinstance(conf, (int, float)): | ||||||
|  |                             texts.append(text.strip()) | ||||||
|  |                             confs.append(float(conf)) | ||||||
|  |                             continue | ||||||
|  |                     elif isinstance(text_data, str): | ||||||
|  |                         # 只有文本没有置信度的情况 | ||||||
|  |                         texts.append(text_data.strip()) | ||||||
|  |                         confs.append(1.0)  # 默认最高置信度 | ||||||
|  |                         continue | ||||||
|  |  | ||||||
|  |                 # 无法识别的格式,记录日志 | ||||||
|  |                 print(f"无法解析的OCR结果格式: {item}") | ||||||
|  |  | ||||||
|         if len(texts) != len(confs): |         if len(texts) != len(confs): | ||||||
|             return (False, "OCR结果格式异常") |             return (False, "OCR结果格式异常") | ||||||
| @ -178,9 +226,16 @@ def detect(frame): | |||||||
|         return (False, f"检测错误: {str(e)}") |         return (False, f"检测错误: {str(e)}") | ||||||
|  |  | ||||||
|     finally: |     finally: | ||||||
|         # 减少引用计数、确保线程安全 |         # 减少引用计数,确保线程安全 | ||||||
|         with _lock: |         with _lock: | ||||||
|             _ref_count = max(0, _ref_count - 1) |             _ref_count = max(0, _ref_count - 1) | ||||||
|             # 持续使用时更新最后使用时间 |  | ||||||
|             if _ref_count > 0: |             if _ref_count > 0: | ||||||
|                 _last_used_time = time.time() |                 _last_used_time = time.time() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def batch_detect(frames): | ||||||
|  |     """批量检测接口,充分利用并行能力""" | ||||||
|  |     results = [] | ||||||
|  |     for frame in frames: | ||||||
|  |         results.append(detect(frame)) | ||||||
|  |     return results | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user