From 9b3d20511ab97638328ad8f69e0635ee7f2febce Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Fri, 5 Sep 2025 17:23:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=80=E6=96=B0=E5=8F=AF=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/all.py | 149 +++++++++++++++++++++++++++------- core/face.py | 222 ++++++++++++++++++++++++++++++++++++++++++++++++--- core/ocr.py | 196 +++++++++++++++++++++++++++++++++++---------- 3 files changed, 482 insertions(+), 85 deletions(-) diff --git a/core/all.py b/core/all.py index 9065b80..6b1a717 100644 --- a/core/all.py +++ b/core/all.py @@ -1,45 +1,136 @@ from core.ocr import load_model as ocrLoadModel, detect as ocrDetect from core.face import load_model as faceLoadModel, detect as faceDetect from core.yolo import load_model as yoloLoadModel, detect as yoloDetect +from concurrent.futures import ThreadPoolExecutor, Future +import threading +import cv2 +import numpy as np -# 添加一个标记变量,用于监控load_model是否已被调用 -_model_loaded = False +# -------------------------- 核心配置参数 -------------------------- +MAX_WORKERS = 6 # 线程池最大线程数 +DETECTION_ORDER = ["yolo", "face", "ocr"] # 检测优先级顺序 +TIMEOUT = 30 # 检测超时时间(秒) + +# -------------------------- 全局状态管理 -------------------------- +_executor = None # 线程池实例 +_model_loaded = False # 模型加载状态标记 +_model_lock = threading.Lock() # 模型加载线程锁 +_executor_lock = threading.Lock() # 线程池初始化锁 +_task_counter = 0 # 任务计数器 +_task_counter_lock = threading.Lock() # 任务计数锁 +# -------------------------- 工具函数 -------------------------- +def _get_next_task_id(): + """获取唯一任务ID,用于日志追踪""" + global _task_counter + with _task_counter_lock: + _task_counter += 1 + return _task_counter + + +# -------------------------- 模型加载 -------------------------- def load_model(): + """加载所有检测模型并初始化线程池(仅执行一次)""" global _model_loaded + if not _model_loaded: + with _model_lock: + if not _model_loaded: + print("=== 开始加载检测模型 ===") - # 如果已经调用过,直接忽略 - if _model_loaded: - return + # 按顺序加载模型 + print("加载YOLO模型...") + yoloLoadModel() - # 首次调用时加载模型 - ocrLoadModel() - faceLoadModel() - yoloLoadModel() + print("加载人脸检测模型...") + faceLoadModel() - # 标记为已调用 - _model_loaded = True + print("加载OCR模型...") + ocrLoadModel() + + _model_loaded = True + print("=== 所有模型加载完成 ===") + + # 初始化线程池 + _init_thread_pool() -def detect(frame): - # 先进行YOLO检测 - yolo_flag, yolo_result = yoloDetect(frame) - print("YOLO检测结果:", yolo_result) - if yolo_flag: - return (True, yolo_result, "yolo") +# -------------------------- 线程池管理 -------------------------- +def _init_thread_pool(): + """初始化线程池(仅内部调用)""" + global _executor + with _executor_lock: + if _executor is None: + _executor = ThreadPoolExecutor( + max_workers=MAX_WORKERS, + thread_name_prefix="DetectionThread" + ) + print(f"=== 线程池初始化完成,最大线程数: {MAX_WORKERS} ===") - # YOLO未检测到,进行人脸检测 - face_flag, face_result = faceDetect(frame) - print("人脸检测结果:", face_result) - if face_flag: - return (True, face_result, "face") - # 人脸未检测到,进行OCR检测 - ocr_flag, ocr_result = ocrDetect(frame) - print("OCR检测结果:", ocr_result) - if ocr_flag: - return (True, ocr_result, "ocr") +def shutdown(): + """关闭线程池,释放资源""" + global _executor + with _executor_lock: + if _executor is not None: + _executor.shutdown(wait=True) + _executor = None + print("=== 线程池已安全关闭 ===") + + +# -------------------------- 检测逻辑实现 -------------------------- +def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple: + """在子线程中执行检测逻辑""" + thread_name = threading.current_thread().name + print(f"任务[{task_id}] 开始执行,线程: {thread_name}") + + try: + # 按照优先级执行检测 + for detector in DETECTION_ORDER: + if detector == "yolo": + flag, result = yoloDetect(frame) + elif detector == "face": + flag, result = faceDetect(frame) + elif detector == "ocr": + flag, result = ocrDetect(frame) + else: + flag, result = False, None + + print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}") + if flag: + print(f"任务[{task_id}] 完成检测,使用检测器: {detector}") + return (True, result, detector, task_id) + + # 所有检测器均未检测到结果 + print(f"任务[{task_id}] 所有检测器均未检测到内容") + return (False, "未检测到任何内容", "none", task_id) + + except Exception as e: + print(f"任务[{task_id}] 检测过程发生错误: {str(e)}") + return (False, f"检测错误: {str(e)}", "error", task_id) + + +# -------------------------- 外部调用接口 -------------------------- +def detect(frame: np.ndarray) -> Future: + """ + 提交检测任务到线程池 + + 参数: + frame: 待检测图像(ndarray格式,cv2.imdecode生成) + + 返回: + Future对象,通过result()方法获取检测结果 + """ + # 确保模型已加载 + if not _model_loaded: + print("警告: 模型尚未加载,将自动加载") + load_model() + + # 生成任务ID + task_id = _get_next_task_id() + + # 提交任务到线程池 + future = _executor.submit(_detect_in_thread, frame, task_id) + print(f"任务[{task_id}]: 已提交到线程池") + return future - # 所有检测都未检测到 - return (False, "未检测到任何内容", "none") \ No newline at end of file diff --git a/core/face.py b/core/face.py index e1edee4..c34cc3a 100644 --- a/core/face.py +++ b/core/face.py @@ -1,27 +1,183 @@ import os import numpy as np import cv2 -from PIL import Image # 确保正确导入Image类 +import gc +import time +import threading +from PIL import Image from insightface.app import FaceAnalysis # 导入获取人脸信息的服务 from service.face_service import get_all_face_name_with_eigenvalue +# 用于检查GPU状态 +try: + import pynvml + + pynvml.nvmlInit() + _nvml_available = True +except ImportError: + print("警告: pynvml库未安装,无法检测GPU状态,将默认使用0号GPU") + _nvml_available = False + # 全局变量 _face_app = None _known_faces_embeddings = {} # 存储姓名到特征值的映射 _known_faces_names = [] # 存储所有已知姓名 +_using_gpu = False # 标记是否使用GPU +_used_gpu_id = -1 # 记录当前使用的GPU ID + +# 资源管理变量 +_ref_count = 0 +_last_used_time = 0 +_lock = threading.Lock() +_release_timeout = 8 # 5秒无使用则释放 +_is_releasing = False # 标记是否正在释放 + +# 调试用计数器 +_debug_counter = { + "created": 0, + "released": 0, + "detected": 0 +} -def load_model(): - """加载人脸识别模型及已知人脸特征库""" - global _face_app, _known_faces_embeddings, _known_faces_names +def check_gpu_availability(gpu_id, threshold=0.7): + """检查指定GPU是否可用(内存使用率低于阈值)""" + if not _nvml_available: + return True # 无法检测时默认认为可用 + + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + usage = mem_info.used / mem_info.total + # 内存使用率低于阈值则认为可用 + return usage < threshold + except Exception as e: + print(f"检查GPU {gpu_id} 状态时出错: {e}") + return False + + +def select_best_gpu(preferred_gpus=[0, 1]): + """选择最佳可用GPU,严格按照首选列表顺序检查,优先使用0号GPU""" + # 首先检查首选GPU列表 + for gpu_id in preferred_gpus: + try: + # 检查GPU是否存在 + if _nvml_available: + pynvml.nvmlDeviceGetHandleByIndex(gpu_id) + + # 检查GPU是否可用 + if check_gpu_availability(gpu_id): + print(f"GPU {gpu_id} 可用,将使用该GPU") + return gpu_id + else: + if gpu_id == 0: + print(f"GPU 0 内存使用率过高(繁忙),尝试切换到其他GPU") + except Exception as e: + print(f"GPU {gpu_id} 不存在或无法访问: {e}") + continue + + # 如果所有首选GPU都不可用,返回-1表示使用CPU + print("所有指定的GPU都不可用,将使用CPU进行计算") + return -1 + + +def _release_engine(): + """释放人脸识别引擎资源""" + global _face_app, _is_releasing, _known_faces_embeddings, _known_faces_names + if not _face_app or _is_releasing: + return + + try: + _is_releasing = True + # 释放InsightFace资源 + if hasattr(_face_app, 'model'): + # 清除模型资源 + _face_app.model = None + _face_app = None + + # 清空人脸数据 + _known_faces_embeddings.clear() + _known_faces_names.clear() + + _debug_counter["released"] += 1 + print(f"Face recognition engine released. Stats: {_debug_counter}") + + # 清理GPU缓存 + gc.collect() + try: + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + except ImportError: + pass + try: + import tensorflow as tf + tf.keras.backend.clear_session() + except ImportError: + pass + finally: + _is_releasing = False + + +def _monitor_thread(): + """监控线程,检查并释放超时未使用的资源""" + global _ref_count, _last_used_time, _face_app + while True: + time.sleep(5) # 每5秒检查一次 + with _lock: + # 只有当引擎存在、没有引用且超时,才释放 + if _face_app and _ref_count == 0 and not _is_releasing: + elapsed = time.time() - _last_used_time + if elapsed > _release_timeout: + print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing face engine") + _release_engine() + + +def load_model(prefer_gpu=True, preferred_gpus=[0, 1]): + """加载人脸识别模型及已知人脸特征库,默认优先使用0号GPU""" + global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id + + # 确保监控线程只启动一次 + if not any(t.name == "FaceMonitor" for t in threading.enumerate()): + threading.Thread(target=_monitor_thread, daemon=True, name="FaceMonitor").start() + print("Face monitor thread started") + + # 如果正在释放中,等待释放完成 + while _is_releasing: + time.sleep(0.1) + + # 如果已经初始化,直接返回 + if _face_app: + return True # 初始化InsightFace模型 try: - _face_app = FaceAnalysis(name='buffalo_l', root=os.path.expanduser('~/.insightface')) - _face_app.prepare(ctx_id=0, det_size=(640, 640)) + # 初始化InsightFace + print("正在初始化InsightFace人脸识别引擎...") + _face_app = FaceAnalysis(name='buffalo_l', root='~/.insightface') + + # 选择合适的GPU,默认优先使用0号 + ctx_id = 0 + if prefer_gpu: + ctx_id = select_best_gpu(preferred_gpus) + _using_gpu = ctx_id != -1 + _used_gpu_id = ctx_id if _using_gpu else -1 + + if _using_gpu: + print(f"成功初始化,使用GPU {ctx_id} 进行计算") + else: + print("成功初始化,使用CPU进行计算") + + # 准备模型 + _face_app.prepare(ctx_id=ctx_id, det_size=(640, 640)) + print("InsightFace人脸识别引擎初始化成功。") + _debug_counter["created"] += 1 + print(f"Face engine initialized. Stats: {_debug_counter}") + except Exception as e: - print(f"Face model load failed: {e}") + print(f"初始化失败: {e}") return False # 从服务获取所有人脸姓名和特征值 @@ -62,19 +218,52 @@ def load_model(): def detect(frame, threshold=0.4): """检测并识别人脸,返回结果元组(是否匹配到已知人脸, 结果字符串)""" - global _face_app, _known_faces_embeddings, _known_faces_names + global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id + global _ref_count, _last_used_time - if not _face_app or not _known_faces_names or frame is None: - return (False, "未初始化或无效帧") + # 验证前置条件 + if frame is None or frame.size == 0: + return (False, "无效帧数据") + + # 增加引用计数并获取引擎实例 + engine = None + with _lock: + _ref_count += 1 + _last_used_time = time.time() + _debug_counter["detected"] += 1 + + # 初始化引擎(如果未初始化且不在释放中) + if not _face_app and not _is_releasing: + if not load_model(prefer_gpu=True): + _ref_count -= 1 # 恢复引用计数 + return (False, "引擎初始化失败") + + # 获取当前引擎引用 + engine = _face_app + + # 检查引擎是否可用 + if not engine or not _known_faces_names: + with _lock: + _ref_count = max(0, _ref_count - 1) + return (False, "人脸识别引擎不可用或未初始化") try: + # 如果使用GPU,确保输入帧在处理前是连续的数组 + if _using_gpu and not frame.flags.contiguous: + frame = np.ascontiguousarray(frame) + faces = _face_app.get(frame) except Exception as e: print(f"Face detect error: {e}") + # 检测到错误时尝试重新选择GPU并重新初始化 + print("尝试重新选择GPU并重新初始化...") + with _lock: + _ref_count = max(0, _ref_count - 1) + load_model(prefer_gpu=True) # 重新初始化时保持默认GPU优先级 return (False, f"检测错误: {str(e)}") result_parts = [] - has_matched = False # 新增标记:是否有匹配到的已知人脸 + has_matched = False # 标记是否有匹配到的已知人脸 for face in faces: # 特征归一化 @@ -109,5 +298,12 @@ def detect(frame, threshold=0.4): else: result_str = "; ".join(result_parts) - # 第一个返回值改为:是否匹配到已知人脸 - return (has_matched, result_str) + # 减少引用计数,确保线程安全 + with _lock: + _ref_count = max(0, _ref_count - 1) + # 持续使用时更新最后使用时间 + if _ref_count > 0: + _last_used_time = time.time() + + # 第一个返回值为:是否匹配到已知人脸 + return (has_matched, result_str) \ No newline at end of file diff --git a/core/ocr.py b/core/ocr.py index 3b38287..720ef29 100644 --- a/core/ocr.py +++ b/core/ocr.py @@ -1,5 +1,8 @@ import os import cv2 +import gc +import time +import threading from rapidocr import RapidOCR from service.sensitive_service import get_all_sensitive_words @@ -7,70 +10,177 @@ from service.sensitive_service import get_all_sensitive_words _ocr_engine = None _forbidden_words = set() _conf_threshold = 0.5 - ocr_config_path = os.path.join(os.path.dirname(__file__), "config", "config.yaml") +# 资源管理变量 +_ref_count = 0 +_last_used_time = 0 +_lock = threading.Lock() +_release_timeout = 5 # 30秒无使用则释放 +_is_releasing = False # 标记是否正在释放 + +# 调试用计数器 +_debug_counter = { + "created": 0, + "released": 0, + "detected": 0 +} + + +def _release_engine(): + """释放OCR引擎资源""" + global _ocr_engine, _is_releasing + if not _ocr_engine or _is_releasing: + return + + try: + _is_releasing = True + # 如果有释放方法则调用 + if hasattr(_ocr_engine, 'release'): + _ocr_engine.release() + _ocr_engine = None + _debug_counter["released"] += 1 + print(f"OCR engine released. Stats: {_debug_counter}") + + # 清理GPU缓存 + gc.collect() + try: + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + except ImportError: + pass + try: + import tensorflow as tf + tf.keras.backend.clear_session() + except ImportError: + pass + finally: + _is_releasing = False + + +def _monitor_thread(): + """监控线程,优化检查逻辑""" + global _ref_count, _last_used_time, _ocr_engine + while True: + time.sleep(5) # 每5秒检查一次 + with _lock: + # 只有当引擎存在、没有引用且超时,才释放 + if _ocr_engine and _ref_count == 0 and not _is_releasing: + elapsed = time.time() - _last_used_time + if elapsed > _release_timeout: + print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing engine") + _release_engine() + def load_model(): - """加载OCR引擎及违禁词列表""" - global _ocr_engine, _forbidden_words, _conf_threshold + """加载违禁词列表和初始化监控线程""" + global _forbidden_words + + # 确保监控线程只启动一次 + if not any(t.name == "OCRMonitor" for t in threading.enumerate()): + threading.Thread(target=_monitor_thread, daemon=True, name="OCRMonitor").start() + print("OCR monitor thread started") # 加载违禁词 try: _forbidden_words = get_all_sensitive_words() + print(f"Loaded {len(_forbidden_words)} forbidden words") except Exception as e: print(f"Forbidden words load error: {e}") + return False - # 初始化OCR引擎 + # 验证配置文件 if not os.path.exists(ocr_config_path): print(f"OCR config not found: {ocr_config_path}") return False - try: - _ocr_engine = RapidOCR(config_path=ocr_config_path) - except Exception as e: - print(f"OCR model load failed: {e}") - return False - - return True if _ocr_engine else False + return True def detect(frame): - """OCR检测并筛选违禁词,返回(是否检测到违禁词, 结果字符串)""" - if not _ocr_engine or not _forbidden_words or frame is None or frame.size == 0: - return (False, "未初始化或无效帧") + """OCR检测,优化引用计数管理""" + global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time + + # 验证前置条件 + if not _forbidden_words: + return (False, "违禁词未初始化") + if frame is None or frame.size == 0: + return (False, "无效帧数据") + if not os.path.exists(ocr_config_path): + return (False, f"OCR配置文件不存在: {ocr_config_path}") + + # 增加引用计数并获取引擎实例 + engine = None + with _lock: + _ref_count += 1 + _last_used_time = time.time() + _debug_counter["detected"] += 1 + + # 初始化引擎(如果未初始化且不在释放中) + if not _ocr_engine and not _is_releasing: + try: + _ocr_engine = RapidOCR(config_path=ocr_config_path) + _debug_counter["created"] += 1 + print(f"OCR engine initialized. Stats: {_debug_counter}") + except Exception as e: + print(f"OCR model load failed: {e}") + _ref_count -= 1 # 恢复引用计数 + return (False, f"引擎初始化失败: {str(e)}") + + # 获取当前引擎引用 + engine = _ocr_engine + + # 检查引擎是否可用 + if not engine: + with _lock: + _ref_count -= 1 + return (False, "OCR引擎不可用") try: - ocr_res = _ocr_engine(frame) + # 执行OCR检测 + ocr_res = engine(frame) + + # 验证OCR结果格式 + if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'): + return (False, "无OCR结果") + + # 处理OCR结果 + texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)] + confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))] + + if len(texts) != len(confs): + return (False, "OCR结果格式异常") + + # 筛选违禁词 + vio_info = [] + for txt, conf in zip(texts, confs): + if conf < _conf_threshold: + continue + matched = [w for w in _forbidden_words if w in txt] + if matched: + vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})") + + # 构建结果 + has_text = len(texts) > 0 + has_violation = len(vio_info) > 0 + + if not has_text: + return (False, "未识别到文本") + elif has_violation: + return (True, "; ".join(vio_info)) + else: + return (False, "未检测到违禁词") + except Exception as e: print(f"OCR detect error: {e}") return (False, f"检测错误: {str(e)}") - if not ocr_res or not hasattr(ocr_res, 'txts') or not hasattr(ocr_res, 'scores'): - return (False, "无OCR结果") - - # 处理OCR结果 - texts = [t.strip() for t in ocr_res.txts if t and isinstance(t, str)] - confs = [c for c in ocr_res.scores if c and isinstance(c, (int, float))] - if len(texts) != len(confs): - return (False, "OCR结果格式异常") - - # 筛选违禁词 - vio_info = [] - for txt, conf in zip(texts, confs): - if conf < _conf_threshold: - continue - matched = [w for w in _forbidden_words if w in txt] - if matched: - vio_info.append(f"文本: '{txt}' 包含违禁词: {', '.join(matched)} (置信度: {conf:.2f})") - - # 构建结果字符串 - has_text = len(texts) > 0 - has_violation = len(vio_info) > 0 - - if not has_text: - return (False, "未识别到文本") - elif has_violation: - return (True, "; ".join(vio_info)) - else: - return (False, "未检测到违禁词") \ No newline at end of file + finally: + # 减少引用计数,确保线程安全 + with _lock: + _ref_count = max(0, _ref_count - 1) + # 持续使用时更新最后使用时间 + if _ref_count > 0: + _last_used_time = time.time()