最新可用
This commit is contained in:
		
							
								
								
									
										196
									
								
								core/ocr.py
									
									
									
									
									
								
							
							
						
						
									
										196
									
								
								core/ocr.py
									
									
									
									
									
								
							| @ -1,5 +1,8 @@ | ||||
| import os | ||||
| import cv2 | ||||
| import gc | ||||
| import time | ||||
| import threading | ||||
| from rapidocr import RapidOCR | ||||
| from service.sensitive_service import get_all_sensitive_words | ||||
|  | ||||
| @ -7,70 +10,177 @@ from service.sensitive_service import get_all_sensitive_words | ||||
| _ocr_engine = None | ||||
| _forbidden_words = set() | ||||
| _conf_threshold = 0.5 | ||||
|  | ||||
| ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml") | ||||
|  | ||||
| # 资源管理变量 | ||||
| _ref_count = 0 | ||||
| _last_used_time = 0 | ||||
| _lock = threading.Lock() | ||||
| _release_timeout = 5  # 30秒无使用则释放 | ||||
| _is_releasing = False  # 标记是否正在释放 | ||||
|  | ||||
| # 调试用计数器 | ||||
| _debug_counter = { | ||||
|     "created": 0, | ||||
|     "released": 0, | ||||
|     "detected": 0 | ||||
| } | ||||
|  | ||||
|  | ||||
| def _release_engine(): | ||||
|     """释放OCR引擎资源""" | ||||
|     global _ocr_engine, _is_releasing | ||||
|     if not _ocr_engine or _is_releasing: | ||||
|         return | ||||
|  | ||||
|     try: | ||||
|         _is_releasing = True | ||||
|         # 如果有释放方法则调用 | ||||
|         if hasattr(_ocr_engine, 'release'): | ||||
|             _ocr_engine.release() | ||||
|         _ocr_engine = None | ||||
|         _debug_counter["released"] += 1 | ||||
|         print(f"OCR engine released. Stats: {_debug_counter}") | ||||
|  | ||||
|         # 清理GPU缓存 | ||||
|         gc.collect() | ||||
|         try: | ||||
|             import torch | ||||
|             if torch.cuda.is_available(): | ||||
|                 torch.cuda.empty_cache() | ||||
|                 torch.cuda.ipc_collect() | ||||
|         except ImportError: | ||||
|             pass | ||||
|         try: | ||||
|             import tensorflow as tf | ||||
|             tf.keras.backend.clear_session() | ||||
|         except ImportError: | ||||
|             pass | ||||
|     finally: | ||||
|         _is_releasing = False | ||||
|  | ||||
|  | ||||
| def _monitor_thread(): | ||||
|     """监控线程,优化检查逻辑""" | ||||
|     global _ref_count, _last_used_time, _ocr_engine | ||||
|     while True: | ||||
|         time.sleep(5)  # 每5秒检查一次 | ||||
|         with _lock: | ||||
|             # 只有当引擎存在、没有引用且超时,才释放 | ||||
|             if _ocr_engine and _ref_count == 0 and not _is_releasing: | ||||
|                 elapsed = time.time() - _last_used_time | ||||
|                 if elapsed > _release_timeout: | ||||
|                     print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine") | ||||
|                     _release_engine() | ||||
|  | ||||
|  | ||||
| def load_model(): | ||||
|     """加载OCR引擎及违禁词列表""" | ||||
|     global _ocr_engine, _forbidden_words, _conf_threshold | ||||
|     """加载违禁词列表和初始化监控线程""" | ||||
|     global _forbidden_words | ||||
|  | ||||
|     # 确保监控线程只启动一次 | ||||
|     if not any(t.name == "OCRMonitor" for t in threading.enumerate()): | ||||
|         threading.Thread(target=_monitor_thread, daemon=True, name="OCRMonitor").start() | ||||
|         print("OCR monitor thread started") | ||||
|  | ||||
|     # 加载违禁词 | ||||
|     try: | ||||
|         _forbidden_words = get_all_sensitive_words() | ||||
|         print(f"Loaded {len(_forbidden_words)} forbidden words") | ||||
|     except Exception as e: | ||||
|         print(f"Forbidden words load error: {e}") | ||||
|         return False | ||||
|  | ||||
|     # 初始化OCR引擎 | ||||
|     # 验证配置文件 | ||||
|     if not os.path.exists(ocr_config_path): | ||||
|         print(f"OCR config not found: {ocr_config_path}") | ||||
|         return False | ||||
|  | ||||
|     try: | ||||
|         _ocr_engine = RapidOCR(config_path=ocr_config_path) | ||||
|     except Exception as e: | ||||
|         print(f"OCR model load failed: {e}") | ||||
|         return False | ||||
|  | ||||
|     return True if _ocr_engine else False | ||||
|     return True | ||||
|  | ||||
|  | ||||
| def detect(frame): | ||||
|     """OCR检测并筛选违禁词,返回(是否检测到违禁词, 结果字符串)""" | ||||
|     if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0: | ||||
|         return (False, "未初始化或无效帧") | ||||
|     """OCR检测,优化引用计数管理""" | ||||
|     global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time | ||||
|  | ||||
|     # 验证前置条件 | ||||
|     if not _forbidden_words: | ||||
|         return (False, "违禁词未初始化") | ||||
|     if frame is None or frame.size == 0: | ||||
|         return (False, "无效帧数据") | ||||
|     if not os.path.exists(ocr_config_path): | ||||
|         return (False, f"OCR配置文件不存在: {ocr_config_path}") | ||||
|  | ||||
|     # 增加引用计数并获取引擎实例 | ||||
|     engine = None | ||||
|     with _lock: | ||||
|         _ref_count += 1 | ||||
|         _last_used_time = time.time() | ||||
|         _debug_counter["detected"] += 1 | ||||
|  | ||||
|         # 初始化引擎(如果未初始化且不在释放中) | ||||
|         if not _ocr_engine and not _is_releasing: | ||||
|             try: | ||||
|                 _ocr_engine = RapidOCR(config_path=ocr_config_path) | ||||
|                 _debug_counter["created"] += 1 | ||||
|                 print(f"OCR engine initialized. Stats: {_debug_counter}") | ||||
|             except Exception as e: | ||||
|                 print(f"OCR model load failed: {e}") | ||||
|                 _ref_count -= 1  # 恢复引用计数 | ||||
|                 return (False, f"引擎初始化失败: {str(e)}") | ||||
|  | ||||
|         # 获取当前引擎引用 | ||||
|         engine = _ocr_engine | ||||
|  | ||||
|     # 检查引擎是否可用 | ||||
|     if not engine: | ||||
|         with _lock: | ||||
|             _ref_count -= 1 | ||||
|         return (False, "OCR引擎不可用") | ||||
|  | ||||
|     try: | ||||
|         ocr_res = _ocr_engine(frame) | ||||
|         # 执行OCR检测 | ||||
|         ocr_res = engine(frame) | ||||
|  | ||||
|         # 验证OCR结果格式 | ||||
|         if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'): | ||||
|             return (False, "无OCR结果") | ||||
|  | ||||
|         # 处理OCR结果 | ||||
|         texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)] | ||||
|         confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))] | ||||
|  | ||||
|         if len(texts) != len(confs): | ||||
|             return (False, "OCR结果格式异常") | ||||
|  | ||||
|         # 筛选违禁词 | ||||
|         vio_info = [] | ||||
|         for txt, conf in zip(texts, confs): | ||||
|             if conf < _conf_threshold: | ||||
|                 continue | ||||
|             matched = [w for w in _forbidden_words if w in txt] | ||||
|             if matched: | ||||
|                 vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})") | ||||
|  | ||||
|         # 构建结果 | ||||
|         has_text = len(texts) > 0 | ||||
|         has_violation = len(vio_info) > 0 | ||||
|  | ||||
|         if not has_text: | ||||
|             return (False, "未识别到文本") | ||||
|         elif has_violation: | ||||
|             return (True, "; ".join(vio_info)) | ||||
|         else: | ||||
|             return (False, "未检测到违禁词") | ||||
|  | ||||
|     except Exception as e: | ||||
|         print(f"OCR detect error: {e}") | ||||
|         return (False, f"检测错误: {str(e)}") | ||||
|  | ||||
|     if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'): | ||||
|         return (False, "无OCR结果") | ||||
|  | ||||
|     # 处理OCR结果 | ||||
|     texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)] | ||||
|     confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))] | ||||
|     if len(texts) != len(confs): | ||||
|         return (False, "OCR结果格式异常") | ||||
|  | ||||
|     # 筛选违禁词 | ||||
|     vio_info = [] | ||||
|     for txt, conf in zip(texts, confs): | ||||
|         if conf < _conf_threshold: | ||||
|             continue | ||||
|         matched = [w for w in _forbidden_words if w in txt] | ||||
|         if matched: | ||||
|             vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})") | ||||
|  | ||||
|     # 构建结果字符串 | ||||
|     has_text = len(texts) > 0 | ||||
|     has_violation = len(vio_info) > 0 | ||||
|  | ||||
|     if not has_text: | ||||
|         return (False, "未识别到文本") | ||||
|     elif has_violation: | ||||
|         return (True, "; ".join(vio_info)) | ||||
|     else: | ||||
|         return (False, "未检测到违禁词") | ||||
|     finally: | ||||
|         # 减少引用计数,确保线程安全 | ||||
|         with _lock: | ||||
|             _ref_count = max(0, _ref_count - 1) | ||||
|             # 持续使用时更新最后使用时间 | ||||
|             if _ref_count > 0: | ||||
|                 _last_used_time = time.time() | ||||
|  | ||||
		Reference in New Issue
	
	Block a user