识别结果保存到对应目录下
This commit is contained in:
		
							
								
								
									
										175
									
								
								core/all.py
									
									
									
									
									
								
							
							
						
						
									
										175
									
								
								core/all.py
									
									
									
									
									
								
							| @ -1,139 +1,70 @@ | |||||||
|  | import cv2 | ||||||
| from core.ocr import load_model as ocrLoadModel, detect as ocrDetect | from core.ocr import load_model as ocrLoadModel, detect as ocrDetect | ||||||
| from core.face import load_model as faceLoadModel, detect as faceDetect | from core.face import load_model as faceLoadModel, detect as faceDetect | ||||||
| from core.yolo import load_model as yoloLoadModel, detect as yoloDetect | from core.yolo import load_model as yoloLoadModel, detect as yoloDetect | ||||||
| from concurrent.futures import ThreadPoolExecutor, Future | # 导入保存路径函数(根据实际文件位置调整导入路径) | ||||||
| import threading | from core.establish import get_image_save_path | ||||||
| import cv2 | # 模型加载状态标记(避免重复加载) | ||||||
| import numpy as np |  | ||||||
|  |  | ||||||
| # -------------------------- 核心配置参数 -------------------------- |  | ||||||
| MAX_WORKERS = 6  # 线程池最大线程数 |  | ||||||
| DETECTION_ORDER = ["yolo", "face", "ocr"]  # 检测执行顺序 |  | ||||||
| TIMEOUT = 30  # 检测超时时间(秒) 【确保此常量可被外部导入】 |  | ||||||
|  |  | ||||||
| # -------------------------- 全局状态管理 -------------------------- |  | ||||||
| _executor = None  # 线程池实例 |  | ||||||
| _model_loaded = False  # 模型加载状态标记 |  | ||||||
| _model_lock = threading.Lock()  # 模型加载线程锁 |  | ||||||
| _executor_lock = threading.Lock()  # 线程池初始化锁 |  | ||||||
| _task_counter = 0  # 任务计数器 |  | ||||||
| _task_counter_lock = threading.Lock()  # 任务计数锁 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 工具函数 -------------------------- | _model_loaded = False | ||||||
| def _get_next_task_id(): |  | ||||||
|     """获取唯一任务ID、用于日志追踪""" |  | ||||||
|     global _task_counter |  | ||||||
|     with _task_counter_lock: |  | ||||||
|         _task_counter += 1 |  | ||||||
|         return _task_counter |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 模型加载 -------------------------- |  | ||||||
| def load_model(): | def load_model(): | ||||||
|     """加载所有检测模型并初始化线程池(仅执行一次)""" |     """加载所有检测模型(仅首次调用时执行)""" | ||||||
|     global _model_loaded |     global _model_loaded | ||||||
|     if not _model_loaded: |     if _model_loaded: | ||||||
|         with _model_lock: |         print("模型已加载,无需重复执行") | ||||||
|             if not _model_loaded: |         return | ||||||
|                 print("=== 开始加载检测模型 ===") |  | ||||||
|  |  | ||||||
|                 # 按顺序加载模型 |     # 依次加载OCR、人脸、YOLO模型 | ||||||
|                 print("加载YOLO模型...") |     ocrLoadModel() | ||||||
|                 yoloLoadModel() |     faceLoadModel() | ||||||
|  |     yoloLoadModel() | ||||||
|  |  | ||||||
|                 print("加载人脸检测模型...") |     _model_loaded = True | ||||||
|                 faceLoadModel() |     print("所有检测模型加载完成") | ||||||
|  |  | ||||||
|                 print("加载OCR模型...") |  | ||||||
|                 ocrLoadModel() |  | ||||||
|  |  | ||||||
|                 _model_loaded = True |  | ||||||
|                 print("=== 所有模型加载完成 ===") |  | ||||||
|  |  | ||||||
|                 # 初始化线程池 |  | ||||||
|                 _init_thread_pool() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 线程池管理 -------------------------- | def detect(frame): | ||||||
| def _init_thread_pool(): |  | ||||||
|     """初始化线程池(仅内部调用)""" |  | ||||||
|     global _executor |  | ||||||
|     with _executor_lock: |  | ||||||
|         if _executor is None: |  | ||||||
|             _executor = ThreadPoolExecutor( |  | ||||||
|                 max_workers=MAX_WORKERS, |  | ||||||
|                 thread_name_prefix="DetectionThread" |  | ||||||
|             ) |  | ||||||
|             print(f"=== 线程池初始化完成、最大线程数: {MAX_WORKERS} ===") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def shutdown(): |  | ||||||
|     """关闭线程池、释放资源""" |  | ||||||
|     global _executor |  | ||||||
|     with _executor_lock: |  | ||||||
|         if _executor is not None: |  | ||||||
|             _executor.shutdown(wait=True) |  | ||||||
|             _executor = None |  | ||||||
|             print("=== 线程池已安全关闭 ===") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 检测逻辑实现 -------------------------- |  | ||||||
| def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple: |  | ||||||
|     """在子线程中执行检测逻辑(返回4元素tuple:检测是否成功、结果数据、检测器类型、任务ID)""" |  | ||||||
|     thread_name = threading.current_thread().name |  | ||||||
|     print(f"任务[{task_id}] 开始执行、线程: {thread_name}") |  | ||||||
|  |  | ||||||
|     try: |  | ||||||
|         # 按照配置顺序执行检测 |  | ||||||
|         for detector in DETECTION_ORDER: |  | ||||||
|             if detector == "yolo": |  | ||||||
|                 success, result = yoloDetect(frame) |  | ||||||
|             elif detector == "face": |  | ||||||
|                 success, result = faceDetect(frame) |  | ||||||
|             elif detector == "ocr": |  | ||||||
|                 success, result = ocrDetect(frame) |  | ||||||
|             else: |  | ||||||
|                 success, result = False, None |  | ||||||
|  |  | ||||||
|             print(f"任务[{task_id}] {detector}检测状态: {'成功' if success else '未检测到内容'}") |  | ||||||
|             if success: |  | ||||||
|                 print(f"任务[{task_id}] 完成检测、使用检测器: {detector}") |  | ||||||
|                 return (success, result, detector, task_id)  # 4元素tuple |  | ||||||
|  |  | ||||||
|         # 所有检测器均未检测到结果 |  | ||||||
|         print(f"任务[{task_id}] 所有检测器均未检测到有效内容") |  | ||||||
|         return (False, "未检测到任何有效内容", "none", task_id)  # 4元素tuple |  | ||||||
|  |  | ||||||
|     except Exception as e: |  | ||||||
|         print(f"任务[{task_id}] 检测过程发生错误: {str(e)}") |  | ||||||
|         return (False, f"检测错误: {str(e)}", "error", task_id)  # 4元素tuple |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # -------------------------- 外部调用接口 -------------------------- |  | ||||||
| def detect(frame: np.ndarray) -> Future: |  | ||||||
|     """ |     """ | ||||||
|     提交检测任务到线程池(返回Future对象,需调用result()获取4元素结果) |     执行模型检测,检测到违规时按指定格式保存图片 | ||||||
|  |     参数: | ||||||
|     参数: |         frame: 待检测的图像帧(OpenCV格式,numpy.ndarray类型) | ||||||
|         frame: 待检测图像(ndarray格式、cv2.imdecode生成) |     返回: | ||||||
|  |         (检测结果布尔值, 检测详情, 检测模型类型) | ||||||
|     返回: |  | ||||||
|         Future对象、result()返回tuple: (success, data, detector_type, task_id) |  | ||||||
|             success: 布尔值,表示是否检测到有效内容 |  | ||||||
|             data: 检测结果数据(成功时为具体结果,失败时为错误信息) |  | ||||||
|             detector_type: 使用的检测器类型("yolo"/"face"/"ocr"/"none"/"error") |  | ||||||
|             task_id: 任务唯一标识 |  | ||||||
|     """ |     """ | ||||||
|     # 确保模型已加载 |     # 1. YOLO检测(优先级1) | ||||||
|     if not _model_loaded: |     yolo_flag, yolo_result = yoloDetect(frame) | ||||||
|         print("警告: 模型尚未加载、将自动加载") |     print(f"YOLO检测结果:{yolo_result}") | ||||||
|         load_model() |     if yolo_flag: | ||||||
|  |         # 直接调用路径生成函数,无需传入原始图片名 | ||||||
|  |         save_path = get_image_save_path(model_type="yolo") | ||||||
|  |         if save_path: | ||||||
|  |             cv2.imwrite(save_path, frame) | ||||||
|  |             print(f"✅ YOLO违规图片已保存:{save_path}") | ||||||
|  |         return (True, yolo_result, "yolo") | ||||||
|  |  | ||||||
|     # 生成任务ID |     # 2. 人脸检测(优先级2) | ||||||
|     task_id = _get_next_task_id() |     face_flag, face_result = faceDetect(frame) | ||||||
|  |     print(f"人脸检测结果:{face_result}") | ||||||
|  |     if face_flag: | ||||||
|  |         save_path = get_image_save_path(model_type="face") | ||||||
|  |         if save_path: | ||||||
|  |             cv2.imwrite(save_path, frame) | ||||||
|  |             print(f"✅ 人脸违规图片已保存:{save_path}") | ||||||
|  |         return (True, face_result, "face") | ||||||
|  |  | ||||||
|     # 提交任务到线程池(返回Future) |     # 3. OCR检测(优先级3) | ||||||
|     future = _executor.submit(_detect_in_thread, frame, task_id) |     ocr_flag, ocr_result = ocrDetect(frame) | ||||||
|     print(f"任务[{task_id}]: 已提交到线程池") |     print(f"OCR检测结果:{ocr_result}") | ||||||
|     return future |     if ocr_flag: | ||||||
|  |         save_path = get_image_save_path(model_type="ocr") | ||||||
|  |         if save_path: | ||||||
|  |             cv2.imwrite(save_path, frame) | ||||||
|  |             print(f"✅ OCR违规图片已保存:{save_path}") | ||||||
|  |         return (True, ocr_result, "ocr") | ||||||
|  |  | ||||||
|  |     # 4. 无违规内容(不保存图片) | ||||||
|  |     print(f"❌ 未检测到任何违规内容,不保存图片") | ||||||
|  |     return (False, "未检测到任何内容", "none") | ||||||
							
								
								
									
										111
									
								
								core/establish.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								core/establish.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,111 @@ | |||||||
|  | import os | ||||||
|  | import datetime | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 配置IP文件路径(统一使用绝对路径) | ||||||
|  | IP_FILE_PATH = Path(r"D:\ccc\IP.txt") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def create_directory_structure(): | ||||||
|  |     """创建项目所需的目录结构""" | ||||||
|  |     try: | ||||||
|  |         # 1. 创建根目录下的resource文件夹 | ||||||
|  |         resource_dir = Path("resource") | ||||||
|  |         resource_dir.mkdir(exist_ok=True) | ||||||
|  |         print(f"确保resource目录存在: {resource_dir.absolute()}") | ||||||
|  |  | ||||||
|  |         # 2. 在resource下创建dect文件夹 | ||||||
|  |         dect_dir = resource_dir / "dect" | ||||||
|  |         dect_dir.mkdir(exist_ok=True) | ||||||
|  |         print(f"确保dect目录存在: {dect_dir.absolute()}") | ||||||
|  |  | ||||||
|  |         # 3. 在dect下创建三个模型文件夹 | ||||||
|  |         model_dirs = ["ocr", "face", "yolo"] | ||||||
|  |         for model in model_dirs: | ||||||
|  |             model_dir = dect_dir / model | ||||||
|  |             model_dir.mkdir(exist_ok=True) | ||||||
|  |             print(f"确保{model}模型目录存在: {model_dir.absolute()}") | ||||||
|  |  | ||||||
|  |         # 4. 读取ip.txt文件获取IP地址 | ||||||
|  |         try: | ||||||
|  |             with open(IP_FILE_PATH, "r") as f: | ||||||
|  |                 ip_addresses = [line.strip() for line in f if line.strip()] | ||||||
|  |  | ||||||
|  |             if not ip_addresses: | ||||||
|  |                 print("警告: ip.txt文件中未找到有效的IP地址") | ||||||
|  |                 return | ||||||
|  |  | ||||||
|  |             print(f"从ip.txt中读取到的IP地址: {ip_addresses}") | ||||||
|  |  | ||||||
|  |             # 5. 获取当前日期 | ||||||
|  |             now = datetime.datetime.now() | ||||||
|  |             current_year = str(now.year) | ||||||
|  |             current_month = str(now.month) | ||||||
|  |  | ||||||
|  |             # 6. 为每个IP在每个模型文件夹下创建年->月的目录结构 | ||||||
|  |             for ip in ip_addresses: | ||||||
|  |                 # 处理IP地址中的特殊字符(如果有) | ||||||
|  |                 safe_ip = ip.replace(".", "_") | ||||||
|  |  | ||||||
|  |                 for model in model_dirs: | ||||||
|  |                     # 构建路径: resource/dect/{model}/{ip}/{year}/{month} | ||||||
|  |                     ip_dir = dect_dir / model / safe_ip | ||||||
|  |                     year_dir = ip_dir / current_year | ||||||
|  |                     month_dir = year_dir / current_month | ||||||
|  |  | ||||||
|  |                     # 创建目录(如果不存在) | ||||||
|  |                     month_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |                     print(f"创建/确保目录存在: {month_dir.absolute()}") | ||||||
|  |  | ||||||
|  |         except FileNotFoundError: | ||||||
|  |             print(f"错误: 未找到ip.txt文件,请确保该文件存在于 {IP_FILE_PATH}") | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"处理IP和日期目录时发生错误: {str(e)}") | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"创建目录结构时发生错误: {str(e)}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_image_save_path(model_type: str) -> str: | ||||||
|  |     """ | ||||||
|  |     获取图片保存的完整路径(不依赖原始图片名称) | ||||||
|  |  | ||||||
|  |     参数: | ||||||
|  |         model_type: 模型类型,应为"ocr"、"face"或"yolo" | ||||||
|  |  | ||||||
|  |     返回: | ||||||
|  |         完整的图片保存路径 | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         # 读取IP地址(假设只有一个IP或使用第一个IP) | ||||||
|  |         with open(IP_FILE_PATH, "r") as f: | ||||||
|  |             ip_addresses = [line.strip() for line in f if line.strip()] | ||||||
|  |  | ||||||
|  |         if not ip_addresses: | ||||||
|  |             raise ValueError("ip.txt文件中未找到有效的IP地址") | ||||||
|  |  | ||||||
|  |         ip = ip_addresses[0] | ||||||
|  |         safe_ip = ip.replace(".", "_") | ||||||
|  |  | ||||||
|  |         # 获取当前日期和时间(精确到毫秒,确保文件名唯一) | ||||||
|  |         now = datetime.datetime.now() | ||||||
|  |         current_year = str(now.year) | ||||||
|  |         current_month = str(now.month) | ||||||
|  |         current_day = str(now.day) | ||||||
|  |         # 生成时间戳字符串(格式:年月日时分秒毫秒) | ||||||
|  |         timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]  # 去除最后三位,保留到毫秒 | ||||||
|  |  | ||||||
|  |         # 构建路径: resource/dect/{model}/{ip}/{year}/{month}/{day} | ||||||
|  |         day_dir = Path("resource") / "dect" / model_type / safe_ip / current_year / current_month / current_day | ||||||
|  |         day_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|  |         # 构建图片文件名(使用时间戳确保唯一性) | ||||||
|  |         image_filename = f"resource_dect_{model_type}_{safe_ip}_{current_year}_{current_month}_{current_day}_{timestamp}.jpg" | ||||||
|  |         image_path = day_dir / image_filename | ||||||
|  |  | ||||||
|  |         return str(image_path) | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"获取图片保存路径时发生错误: {str(e)}") | ||||||
|  |         return "" | ||||||
							
								
								
									
										305
									
								
								core/face.py
									
									
									
									
									
								
							
							
						
						
									
										305
									
								
								core/face.py
									
									
									
									
									
								
							| @ -6,203 +6,217 @@ import time | |||||||
| import threading | import threading | ||||||
| from PIL import Image | from PIL import Image | ||||||
| from insightface.app import FaceAnalysis | from insightface.app import FaceAnalysis | ||||||
| # 导入获取人脸信息的服务 | # 假设service.face_service中get_all_face_name_with_eigenvalue可获取人脸数据 | ||||||
| from service.face_service import get_all_face_name_with_eigenvalue | from service.face_service import get_all_face_name_with_eigenvalue | ||||||
|  |  | ||||||
| # 用于检查GPU状态 | # GPU状态检查支持 | ||||||
| try: | try: | ||||||
|     import pynvml |     import pynvml | ||||||
|  |  | ||||||
|     pynvml.nvmlInit() |     pynvml.nvmlInit() | ||||||
|     _nvml_available = True |     _nvml_available = True | ||||||
| except ImportError: | except ImportError: | ||||||
|     print("警告: pynvml库未安装、无法检测GPU状态、将默认使用0号GPU") |     print("警告: pynvml库未安装,无法检测GPU状态,默认尝试使用GPU") | ||||||
|     _nvml_available = False |     _nvml_available = False | ||||||
|  |  | ||||||
| # 全局变量 | # 全局人脸引擎与特征库 | ||||||
| _face_app = None | _face_app = None | ||||||
| _known_faces_embeddings = {}  # 存储姓名到特征值的映射 | _known_faces_embeddings = {}  # 姓名 -> 归一化特征值的映射 | ||||||
| _known_faces_names = []  # 存储所有已知姓名 | _known_faces_names = []  # 已知人脸姓名列表 | ||||||
| _using_gpu = False  # 标记是否使用GPU |  | ||||||
| _used_gpu_id = -1  # 记录当前使用的GPU ID | # GPU使用状态标记 | ||||||
|  | _using_gpu = False  # 是否使用GPU | ||||||
|  | _used_gpu_id = -1  # 使用的GPU ID(-1表示CPU) | ||||||
|  |  | ||||||
| # 资源管理变量 | # 资源管理变量 | ||||||
| _ref_count = 0 | _ref_count = 0  # 引擎引用计数(记录当前使用次数) | ||||||
| _last_used_time = 0 | _last_used_time = 0  # 最后一次使用引擎的时间 | ||||||
| _lock = threading.Lock() | _lock = threading.Lock()  # 线程安全锁 | ||||||
| _release_timeout = 8  # 5秒无使用则释放 | _release_timeout = 8  # 闲置超时时间(秒) | ||||||
| _is_releasing = False  # 标记是否正在释放 | _is_releasing = False  # 资源释放中标记 | ||||||
|  | _monitor_thread_running = False  # 监控线程运行标记 | ||||||
|  |  | ||||||
| # 调试用计数器 | # 调试计数器 | ||||||
| _debug_counter = { | _debug_counter = { | ||||||
|     "created": 0, |     "engine_created": 0,  # 引擎创建次数 | ||||||
|     "released": 0, |     "engine_released": 0,  # 引擎释放次数 | ||||||
|     "detected": 0 |     "detection_calls": 0  # 检测函数调用次数 | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| def check_gpu_availability(gpu_id, threshold=0.7): | def check_gpu_availability(gpu_id, memory_threshold=0.7): | ||||||
|     """检查指定GPU是否可用(内存使用率低于阈值)""" |     """检查指定GPU的内存使用率是否低于阈值(判定为“可用”)""" | ||||||
|     if not _nvml_available: |     if not _nvml_available: | ||||||
|         return True  # 无法检测时默认认为可用 |         return True | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) |         handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id) | ||||||
|         mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) |         mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) | ||||||
|         usage = mem_info.used / mem_info.total |         memory_usage = mem_info.used / mem_info.total | ||||||
|         # 内存使用率低于阈值则认为可用 |         return memory_usage < memory_threshold | ||||||
|         return usage < threshold |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"检查GPU {gpu_id} 状态时出错: {e}") |         print(f"检查GPU {gpu_id} 状态失败: {e}") | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|  |  | ||||||
| def select_best_gpu(preferred_gpus=[0, 1]): | def select_best_gpu(preferred_gpus=[0, 1]): | ||||||
|     """选择最佳可用GPU、严格按照首选列表顺序检查、优先使用0号GPU""" |     """按优先级选择可用GPU,优先0号;均不可用则返回-1(CPU)""" | ||||||
|     # 首先检查首选GPU列表 |  | ||||||
|     for gpu_id in preferred_gpus: |     for gpu_id in preferred_gpus: | ||||||
|         try: |         try: | ||||||
|             # 检查GPU是否存在 |             # 验证GPU是否存在 | ||||||
|             if _nvml_available: |             if _nvml_available: | ||||||
|                 pynvml.nvmlDeviceGetHandleByIndex(gpu_id) |                 pynvml.nvmlDeviceGetHandleByIndex(gpu_id) | ||||||
|  |             # 验证GPU内存是否充足 | ||||||
|             # 检查GPU是否可用 |  | ||||||
|             if check_gpu_availability(gpu_id): |             if check_gpu_availability(gpu_id): | ||||||
|                 print(f"GPU {gpu_id} 可用、将使用该GPU") |                 print(f"GPU {gpu_id} 可用,将使用该GPU") | ||||||
|                 return gpu_id |                 return gpu_id | ||||||
|             else: |             else: | ||||||
|                 if gpu_id == 0: |                 if gpu_id == 0: | ||||||
|                     print(f"GPU 0 内存使用率过高(繁忙)、尝试切换到其他GPU") |                     print("GPU 0 内存使用率过高,尝试其他GPU") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"GPU {gpu_id} 不存在或无法访问: {e}") |             print(f"GPU {gpu_id} 不可用或访问失败: {e}") | ||||||
|             continue |     print("所有指定GPU均不可用,将使用CPU计算") | ||||||
|  |  | ||||||
|     # 如果所有首选GPU都不可用、返回-1表示使用CPU |  | ||||||
|     print("所有指定的GPU都不可用、将使用CPU进行计算") |  | ||||||
|     return -1 |     return -1 | ||||||
|  |  | ||||||
|  |  | ||||||
| def _release_engine(): | def _release_engine_resources(): | ||||||
|     """释放人脸识别引擎资源""" |     """释放人脸引擎的所有资源(模型、特征库、GPU缓存等)""" | ||||||
|     global _face_app, _is_releasing, _known_faces_embeddings, _known_faces_names |     global _face_app, _is_releasing, _known_faces_embeddings, _known_faces_names | ||||||
|     if not _face_app or _is_releasing: |     if not _face_app or _is_releasing: | ||||||
|         return |         return | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         _is_releasing = True |         _is_releasing = True | ||||||
|         # 释放InsightFace资源 |         print("开始释放人脸引擎资源...") | ||||||
|         if hasattr(_face_app, 'model'): |  | ||||||
|             # 清除模型资源 |  | ||||||
|             _face_app.model = None |  | ||||||
|         _face_app = None |  | ||||||
|  |  | ||||||
|         # 清空人脸数据 |         # 释放InsightFace模型资源 | ||||||
|  |         if hasattr(_face_app, "model"): | ||||||
|  |             _face_app.model = None  # 显式置空模型引用 | ||||||
|  |         _face_app = None  # 释放引擎实例 | ||||||
|  |  | ||||||
|  |         # 清空人脸特征库 | ||||||
|         _known_faces_embeddings.clear() |         _known_faces_embeddings.clear() | ||||||
|         _known_faces_names.clear() |         _known_faces_names.clear() | ||||||
|  |  | ||||||
|         _debug_counter["released"] += 1 |         _debug_counter["engine_released"] += 1 | ||||||
|         print(f"Face recognition engine released. Stats: {_debug_counter}") |         print(f"人脸引擎已释放,调试统计: {_debug_counter}") | ||||||
|  |  | ||||||
|         # 清理GPU缓存 |         # 强制垃圾回收 | ||||||
|         gc.collect() |         gc.collect() | ||||||
|  |  | ||||||
|  |         # 清理各深度学习框架的GPU缓存 | ||||||
|  |         # Torch 缓存清理 | ||||||
|         try: |         try: | ||||||
|             import torch |             import torch | ||||||
|             if torch.cuda.is_available(): |             if torch.cuda.is_available(): | ||||||
|                 torch.cuda.empty_cache() |                 torch.cuda.empty_cache() | ||||||
|                 torch.cuda.ipc_collect() |                 torch.cuda.ipc_collect() | ||||||
|  |                 print("Torch GPU缓存已清理") | ||||||
|         except ImportError: |         except ImportError: | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|  |         # TensorFlow 缓存清理 | ||||||
|         try: |         try: | ||||||
|             import tensorflow as tf |             import tensorflow as tf | ||||||
|             tf.keras.backend.clear_session() |             tf.keras.backend.clear_session() | ||||||
|  |             print("TensorFlow会话已清理") | ||||||
|         except ImportError: |         except ImportError: | ||||||
|             pass |             pass | ||||||
|  |  | ||||||
|  |         # MXNet 缓存清理(InsightFace底层常用MXNet) | ||||||
|  |         try: | ||||||
|  |             import mxnet as mx | ||||||
|  |             mx.nd.waitall()  # 等待所有计算完成并释放资源 | ||||||
|  |             print("MXNet资源已等待释放") | ||||||
|  |         except ImportError: | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"释放资源过程中出错: {e}") | ||||||
|     finally: |     finally: | ||||||
|         _is_releasing = False |         _is_releasing = False | ||||||
|  |  | ||||||
|  |  | ||||||
| def _monitor_thread(): | def _resource_monitor_thread(): | ||||||
|     """监控线程、检查并释放超时未使用的资源""" |     """后台监控线程:检测引擎闲置超时,触发资源释放""" | ||||||
|     global _ref_count, _last_used_time, _face_app |     global _ref_count, _last_used_time, _face_app, _monitor_thread_running | ||||||
|     while True: |     _monitor_thread_running = True | ||||||
|         time.sleep(5)  # 每5秒检查一次 |     while _monitor_thread_running: | ||||||
|  |         time.sleep(2)  # 缩短检查间隔,加快闲置检测响应 | ||||||
|         with _lock: |         with _lock: | ||||||
|             # 只有当引擎存在、没有引用且超时、才释放 |             # 当“引擎存在 + 无引用 + 未在释放中”时,检查闲置时间 | ||||||
|             if _face_app and _ref_count == 0 and not _is_releasing: |             if _face_app and _ref_count == 0 and not _is_releasing: | ||||||
|                 elapsed = time.time() - _last_used_time |                 idle_time = time.time() - _last_used_time | ||||||
|                 if elapsed > _release_timeout: |                 if idle_time > _release_timeout: | ||||||
|                     print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing face engine") |                     print(f"引擎闲置超时({idle_time:.1f}s > {_release_timeout}s),释放资源") | ||||||
|                     _release_engine() |                     _release_engine_resources() | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_model(prefer_gpu=True, preferred_gpus=[0, 1]): | def load_model(prefer_gpu=True, preferred_gpus=[0, 1]): | ||||||
|     """加载人脸识别模型及已知人脸特征库、默认优先使用0号GPU""" |     """加载人脸识别引擎及已知人脸特征库(默认优先用0号GPU)""" | ||||||
|     global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id |     global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id | ||||||
|  |  | ||||||
|     # 确保监控线程只启动一次 |     # 启动后台监控线程(确保仅启动一次) | ||||||
|     if not any(t.name == "FaceMonitor" for t in threading.enumerate()): |     if not _monitor_thread_running: | ||||||
|         threading.Thread(target=_monitor_thread, daemon=True, name="FaceMonitor").start() |         threading.Thread( | ||||||
|         print("Face monitor thread started") |             target=_resource_monitor_thread, | ||||||
|  |             daemon=True, | ||||||
|  |             name="FaceEngineMonitor" | ||||||
|  |         ).start() | ||||||
|  |         print("人脸引擎监控线程已启动") | ||||||
|  |  | ||||||
|     # 如果正在释放中、等待释放完成 |     # 若正在释放资源,等待释放完成 | ||||||
|     while _is_releasing: |     while _is_releasing: | ||||||
|         time.sleep(0.1) |         time.sleep(0.1) | ||||||
|  |  | ||||||
|     # 如果已经初始化、直接返回 |     # 若引擎已初始化,直接返回 | ||||||
|     if _face_app: |     if _face_app: | ||||||
|         return True |         return True | ||||||
|  |  | ||||||
|     # 初始化InsightFace模型 |     # 初始化InsightFace引擎 | ||||||
|     try: |     try: | ||||||
|         # 初始化InsightFace |  | ||||||
|         print("正在初始化InsightFace人脸识别引擎...") |         print("正在初始化InsightFace人脸识别引擎...") | ||||||
|         _face_app = FaceAnalysis(name='buffalo_l', root='~/.insightface') |         _face_app = FaceAnalysis(name="buffalo_l", root=os.path.expanduser("~/.insightface")) | ||||||
|  |  | ||||||
|         # 选择合适的GPU、默认优先使用0号 |         # 选择GPU(优先用0号) | ||||||
|         ctx_id = 0 |         ctx_id = 0 | ||||||
|         if prefer_gpu: |         if prefer_gpu: | ||||||
|             ctx_id = select_best_gpu(preferred_gpus) |             ctx_id = select_best_gpu(preferred_gpus) | ||||||
|             _using_gpu = ctx_id != -1 |             _using_gpu = ctx_id != -1 | ||||||
|             _used_gpu_id = ctx_id if _using_gpu else -1 |             _used_gpu_id = ctx_id if _using_gpu else -1 | ||||||
|  |  | ||||||
|             if _using_gpu: |         if _using_gpu: | ||||||
|                 print(f"成功初始化、使用GPU {ctx_id} 进行计算") |             print(f"引擎初始化成功,将使用GPU {ctx_id} 计算") | ||||||
|             else: |         else: | ||||||
|                 print("成功初始化、使用CPU进行计算") |             print("引擎初始化成功,将使用CPU计算") | ||||||
|  |  | ||||||
|         # 准备模型 |         # 准备模型(加载到指定设备) | ||||||
|         _face_app.prepare(ctx_id=ctx_id, det_size=(640, 640)) |         _face_app.prepare(ctx_id=ctx_id, det_size=(640, 640)) | ||||||
|         print("InsightFace人脸识别引擎初始化成功。") |         print("InsightFace引擎初始化完成") | ||||||
|         _debug_counter["created"] += 1 |         _debug_counter["engine_created"] += 1 | ||||||
|         print(f"Face engine initialized. Stats: {_debug_counter}") |         print(f"引擎调试统计: {_debug_counter}") | ||||||
|  |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"初始化失败: {e}") |         print(f"引擎初始化失败: {e}") | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     # 从服务获取所有人脸姓名和特征值 |     # 从服务加载已知人脸的姓名和特征值 | ||||||
|     try: |     try: | ||||||
|         face_data = get_all_face_name_with_eigenvalue() |         face_data = get_all_face_name_with_eigenvalue() | ||||||
|  |  | ||||||
|         # 处理获取到的人脸数据 |  | ||||||
|         for person_name, eigenvalue_data in face_data.items(): |         for person_name, eigenvalue_data in face_data.items(): | ||||||
|             # 处理特征值数据 - 兼容数组和字符串两种格式 |             # 兼容“numpy数组”和“字符串”格式的特征值 | ||||||
|             if isinstance(eigenvalue_data, np.ndarray): |             if isinstance(eigenvalue_data, np.ndarray): | ||||||
|                 # 如果已经是numpy数组、直接使用 |  | ||||||
|                 eigenvalue = eigenvalue_data.astype(np.float32) |                 eigenvalue = eigenvalue_data.astype(np.float32) | ||||||
|             elif isinstance(eigenvalue_data, str): |             elif isinstance(eigenvalue_data, str): | ||||||
|                 # 清理字符串: 移除方括号、换行符和多余空格 |                 # 清理字符串中的括号、换行等干扰符 | ||||||
|                 cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip() |                 cleaned = eigenvalue_data.replace("[", "").replace("]", "").replace("\n", "").strip() | ||||||
|                 # 按空格或逗号分割(处理可能的不同分隔符) |                 # 分割并转换为浮点数数组 | ||||||
|                 values = [v for v in cleaned.split() if v] |                 values = [v for v in cleaned.split() if v]  # 兼容空格/逗号分隔 | ||||||
|                 # 转换为数组 |  | ||||||
|                 eigenvalue = np.array(list(map(float, values)), dtype=np.float32) |                 eigenvalue = np.array(list(map(float, values)), dtype=np.float32) | ||||||
|             else: |             else: | ||||||
|                 # 不支持的类型 |                 print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}") | ||||||
|                 print(f"Unsupported eigenvalue type for {person_name}") |  | ||||||
|                 continue |                 continue | ||||||
|  |  | ||||||
|             # 归一化处理 |             # 特征值归一化(保证后续相似度计算的一致性) | ||||||
|             norm = np.linalg.norm(eigenvalue) |             norm = np.linalg.norm(eigenvalue) | ||||||
|             if norm != 0: |             if norm != 0: | ||||||
|                 eigenvalue = eigenvalue / norm |                 eigenvalue = eigenvalue / norm | ||||||
| @ -210,100 +224,103 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]): | |||||||
|             _known_faces_embeddings[person_name] = eigenvalue |             _known_faces_embeddings[person_name] = eigenvalue | ||||||
|             _known_faces_names.append(person_name) |             _known_faces_names.append(person_name) | ||||||
|  |  | ||||||
|  |         print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库") | ||||||
|  |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"Error loading face data from service: {e}") |         print(f"加载人脸特征库失败: {e}") | ||||||
|  |  | ||||||
|     return True if _face_app else False |     return _face_app is not None | ||||||
|  |  | ||||||
|  |  | ||||||
| def detect(frame, threshold=0.4): | def detect(frame, similarity_threshold=0.4): | ||||||
|     """检测并识别人脸、返回结果元组(是否匹配到已知人脸, 结果字符串)""" |     """ | ||||||
|     global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id |     检测并识别人脸 | ||||||
|     global _ref_count, _last_used_time |     返回:(是否匹配到已知人脸, 结果描述字符串) | ||||||
|  |     """ | ||||||
|  |     global _face_app, _known_faces_embeddings, _known_faces_names, _ref_count, _last_used_time | ||||||
|  |  | ||||||
|     # 验证前置条件 |     # 校验输入帧有效性 | ||||||
|     if frame is None or frame.size == 0: |     if frame is None or frame.size == 0: | ||||||
|         return (False, "无效帧数据") |         return (False, "无效的输入帧数据") | ||||||
|  |  | ||||||
|     # 增加引用计数并获取引擎实例 |     # 加锁并更新引用计数、最后使用时间 | ||||||
|     engine = None |     engine = None | ||||||
|     with _lock: |     with _lock: | ||||||
|         _ref_count += 1 |         _ref_count += 1 | ||||||
|         _last_used_time = time.time() |         _last_used_time = time.time() | ||||||
|         _debug_counter["detected"] += 1 |         _debug_counter["detection_calls"] += 1 | ||||||
|  |  | ||||||
|         # 初始化引擎(如果未初始化且不在释放中) |         # 若引擎未初始化且未在释放中,尝试初始化 | ||||||
|         if not _face_app and not _is_releasing: |         if not _face_app and not _is_releasing: | ||||||
|             if not load_model(prefer_gpu=True): |             if not load_model(prefer_gpu=True): | ||||||
|                 _ref_count -= 1  # 恢复引用计数 |                 # 初始化失败,恢复引用计数 | ||||||
|                 return (False, "引擎初始化失败") |                 with _lock: | ||||||
|  |                     _ref_count = max(0, _ref_count - 1) | ||||||
|  |                 return (False, "人脸引擎初始化失败") | ||||||
|  |  | ||||||
|         # 获取当前引擎引用 |         engine = _face_app  # 获取引擎引用 | ||||||
|         engine = _face_app |  | ||||||
|  |  | ||||||
|     # 检查引擎是否可用 |     # 校验引擎可用性 | ||||||
|     if not engine or not _known_faces_names: |     if not engine or len(_known_faces_names) == 0: | ||||||
|         with _lock: |         with _lock: | ||||||
|             _ref_count = max(0, _ref_count - 1) |             _ref_count = max(0, _ref_count - 1) | ||||||
|         return (False, "人脸识别引擎不可用或未初始化") |         return (False, "人脸引擎不可用或特征库为空") | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         # 如果使用GPU、确保输入帧在处理前是连续的数组 |         # GPU计算时,确保帧数据是连续内存(避免CUDA错误) | ||||||
|         if _using_gpu and not frame.flags.contiguous: |         if _using_gpu and engine is not None and not frame.flags.contiguous: | ||||||
|             frame = np.ascontiguousarray(frame) |             frame = np.ascontiguousarray(frame) | ||||||
|  |  | ||||||
|         faces = _face_app.get(frame) |         # 执行人脸检测与特征提取 | ||||||
|  |         faces = engine.get(frame) | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"Face detect error: {e}") |         print(f"人脸检测过程出错: {e}") | ||||||
|         # 检测到错误时尝试重新选择GPU并重新初始化 |         # 出错时尝试重新初始化引擎(可能是GPU状态变化导致) | ||||||
|         print("尝试重新选择GPU并重新初始化...") |         print("尝试重新初始化人脸引擎...") | ||||||
|         with _lock: |         with _lock: | ||||||
|             _ref_count = max(0, _ref_count - 1) |             _ref_count = max(0, _ref_count - 1) | ||||||
|         load_model(prefer_gpu=True)  # 重新初始化时保持默认GPU优先级 |         load_model(prefer_gpu=True) | ||||||
|         return (False, f"检测错误: {str(e)}") |         return (False, f"检测错误: {str(e)}") | ||||||
|  |  | ||||||
|     result_parts = [] |     result_parts = [] | ||||||
|     has_matched = False  # 标记是否有匹配到的已知人脸 |     has_matched_known_face = False  # 是否有任意人脸匹配到已知库 | ||||||
|  |  | ||||||
|     for face in faces: |     for face in faces: | ||||||
|         # 特征归一化 |         # 归一化当前检测到的人脸特征 | ||||||
|         embedding = face.embedding.astype(np.float32) |         face_embedding = face.embedding.astype(np.float32) | ||||||
|         norm = np.linalg.norm(embedding) |         norm = np.linalg.norm(face_embedding) | ||||||
|         if norm == 0: |         if norm == 0: | ||||||
|             continue |             continue | ||||||
|         embedding = embedding / norm |         face_embedding = face_embedding / norm | ||||||
|  |  | ||||||
|         # 对比已知人脸 |         # 与已知人脸特征逐一比对 | ||||||
|         max_sim, best_name = -1.0, "Unknown" |         max_similarity, best_match_name = -1.0, "Unknown" | ||||||
|         for name in _known_faces_names: |         for name in _known_faces_names: | ||||||
|             known_emb = _known_faces_embeddings[name] |             known_emb = _known_faces_embeddings[name] | ||||||
|             sim = np.dot(embedding, known_emb) |             similarity = np.dot(face_embedding, known_emb)  # 余弦相似度 | ||||||
|             if sim > max_sim: |             if similarity > max_similarity: | ||||||
|                 max_sim = sim |                 max_similarity = similarity | ||||||
|                 best_name = name |                 best_match_name = name | ||||||
|  |  | ||||||
|         # 判断匹配结果 |         # 判断是否匹配成功 | ||||||
|         is_match = max_sim >= threshold |         is_matched = max_similarity >= similarity_threshold | ||||||
|         if is_match: |         if is_matched: | ||||||
|             has_matched = True  # 只要有一个匹配成功、就标记为True |             has_matched_known_face = True | ||||||
|  |  | ||||||
|         bbox = face.bbox |         # 记录该人脸的检测结果 | ||||||
|  |         bbox = face.bbox  # 人脸边界框 | ||||||
|         result_parts.append( |         result_parts.append( | ||||||
|             f"{'匹配' if is_match else '不匹配'}: {best_name} (相似度: {max_sim:.2f}, 边界框: {bbox})" |             f"{'匹配' if is_matched else '未匹配'}: {best_match_name} " | ||||||
|  |             f"(相似度: {max_similarity:.2f}, 边界框: {bbox.astype(int).tolist()})" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     # 构建结果字符串 |     # 构建最终结果字符串 | ||||||
|     if not result_parts: |     result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts) | ||||||
|         result_str = "未检测到人脸" |  | ||||||
|     else: |  | ||||||
|         result_str = "; ".join(result_parts) |  | ||||||
|  |  | ||||||
|     # 减少引用计数、确保线程安全 |     # 释放引用计数(线程安全) | ||||||
|     with _lock: |     with _lock: | ||||||
|         _ref_count = max(0, _ref_count - 1) |         _ref_count = max(0, _ref_count - 1) | ||||||
|         # 持续使用时更新最后使用时间 |         # 若仍有引用,更新最后使用时间;若引用为0,也立即标记(加快闲置检测) | ||||||
|         if _ref_count > 0: |         _last_used_time = time.time() | ||||||
|             _last_used_time = time.time() |  | ||||||
|  |  | ||||||
|     # 第一个返回值为: 是否匹配到已知人脸 |     return (has_matched_known_face, result_str) | ||||||
|     return (has_matched, result_str) |  | ||||||
							
								
								
									
										14
									
								
								core/ocr.py
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								core/ocr.py
									
									
									
									
									
								
							| @ -167,7 +167,19 @@ def detect(frame): | |||||||
|                 items_to_process = [line] |                 items_to_process = [line] | ||||||
|  |  | ||||||
|             for item in items_to_process: |             for item in items_to_process: | ||||||
|                 # 跳过纯数字列表(可能是坐标信息) |                 # 精确识别并忽略图片坐标位置信息 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] | ||||||
|  |                 if isinstance(item, list) and len(item) == 4:  # 四边形有4个顶点 | ||||||
|  |                     is_coordinate = True | ||||||
|  |                     for point in item: | ||||||
|  |                         # 每个顶点应该是包含2个数字的列表 | ||||||
|  |                         if not (isinstance(point, list) and len(point) == 2 and | ||||||
|  |                                 all(isinstance(coord, (int, float)) for coord in point)): | ||||||
|  |                             is_coordinate = False | ||||||
|  |                             break | ||||||
|  |                     if is_coordinate: | ||||||
|  |                         continue  # 是坐标信息,直接忽略 | ||||||
|  |  | ||||||
|  |                 # 跳过纯数字列表(其他可能的坐标形式) | ||||||
|                 if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item): |                 if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item): | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								main.py
									
									
									
									
									
								
							| @ -12,7 +12,7 @@ from service.sensitive_service import router as sensitive_router | |||||||
| from service.face_service import router as face_router | from service.face_service import router as face_router | ||||||
| from service.device_service import router as device_router | from service.device_service import router as device_router | ||||||
| from ws.ws import ws_router, lifespan | from ws.ws import ws_router, lifespan | ||||||
|  | from core.establish import create_directory_structure | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 初始化 FastAPI 应用、指定生命周期管理 | # 初始化 FastAPI 应用、指定生命周期管理 | ||||||
| @ -47,6 +47,8 @@ if __name__ == "__main__": | |||||||
|     YOLO_MODEL_PATH = r"/core/models\best.pt" |     YOLO_MODEL_PATH = r"/core/models\best.pt" | ||||||
|     OCR_CONFIG_PATH = r"/core/config\config.yaml" |     OCR_CONFIG_PATH = r"/core/config\config.yaml" | ||||||
|  |  | ||||||
|  |     create_directory_structure() | ||||||
|  |  | ||||||
|     # 初始化项目(默认端口设为8000、避免初始化失败时port未定义) |     # 初始化项目(默认端口设为8000、避免初始化失败时port未定义) | ||||||
|     port = int(SERVER_CONFIG.get("port", 8000)) |     port = int(SERVER_CONFIG.get("port", 8000)) | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										91
									
								
								ws/ws.py
									
									
									
									
									
								
							
							
						
						
									
										91
									
								
								ws/ws.py
									
									
									
									
									
								
							| @ -3,12 +3,11 @@ import datetime | |||||||
| import json | import json | ||||||
| import os | import os | ||||||
| from contextlib import asynccontextmanager | from contextlib import asynccontextmanager | ||||||
| from typing import Dict, Optional, AsyncGenerator | from typing import Dict, Optional | ||||||
| from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip | from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip | ||||||
| from service.device_action_service import add_device_action | from service.device_action_service import add_device_action | ||||||
| from schema.device_action_schema import DeviceActionCreate | from schema.device_action_schema import DeviceActionCreate | ||||||
| # 【修改1:导入detect和TIMEOUT(用于检测超时控制)】 | from core.all import detect, load_model | ||||||
| from core.all import detect, load_model, TIMEOUT |  | ||||||
|  |  | ||||||
| import cv2 | import cv2 | ||||||
| import numpy as np | import numpy as np | ||||||
| @ -21,7 +20,7 @@ WS_ENDPOINT = "/ws"  # WebSocket端点路径 | |||||||
| FRAME_QUEUE_SIZE = 1  # 帧队列大小限制 | FRAME_QUEUE_SIZE = 1  # 帧队列大小限制 | ||||||
|  |  | ||||||
|  |  | ||||||
| # 工具函数: 获取格式化时间字符串(统一时间戳格式) | # 工具函数: 获取格式化时间字符串 | ||||||
| def get_current_time_str() -> str: | def get_current_time_str() -> str: | ||||||
|     return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") |     return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") | ||||||
|  |  | ||||||
| @ -40,13 +39,13 @@ class ClientConnection: | |||||||
|         self.consumer_task: Optional[asyncio.Task] = None |         self.consumer_task: Optional[asyncio.Task] = None | ||||||
|  |  | ||||||
|     def update_heartbeat(self): |     def update_heartbeat(self): | ||||||
|         """更新心跳时间(客户端发送心跳时调用)""" |         """更新心跳时间""" | ||||||
|         self.last_heartbeat = datetime.datetime.now() |         self.last_heartbeat = datetime.datetime.now() | ||||||
|  |  | ||||||
|     def is_alive(self) -> bool: |     def is_alive(self) -> bool: | ||||||
|         """判断客户端是否存活(心跳超时检查)""" |         """判断客户端是否存活""" | ||||||
|         timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() |         timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds() | ||||||
|         return timeout < HEARTBEAT_TIMEOUT |         return timeout_seconds < HEARTBEAT_TIMEOUT | ||||||
|  |  | ||||||
|     def start_consumer(self): |     def start_consumer(self): | ||||||
|         """启动帧消费任务""" |         """启动帧消费任务""" | ||||||
| @ -54,10 +53,7 @@ class ClientConnection: | |||||||
|         return self.consumer_task |         return self.consumer_task | ||||||
|  |  | ||||||
|     async def send_frame_permit(self): |     async def send_frame_permit(self): | ||||||
|         """ |         """发送帧发送许可信号""" | ||||||
|         发送「帧发送许可信号」 |  | ||||||
|         通知客户端可发送下一帧图像 |  | ||||||
|         """ |  | ||||||
|         try: |         try: | ||||||
|             frame_permit_msg = { |             frame_permit_msg = { | ||||||
|                 "type": "frame", |                 "type": "frame", | ||||||
| @ -65,26 +61,21 @@ class ClientConnection: | |||||||
|                 "client_ip": self.client_ip |                 "client_ip": self.client_ip | ||||||
|             } |             } | ||||||
|             await self.websocket.send_json(frame_permit_msg) |             await self.websocket.send_json(frame_permit_msg) | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号(取帧后立即通知)") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}") | ||||||
|  |  | ||||||
|     async def consume_frames(self) -> None: |     async def consume_frames(self) -> None: | ||||||
|         """消费队列中的帧并处理(核心调整: 取帧后立即发许可、再处理帧)""" |         """消费队列中的帧并处理""" | ||||||
|         try: |         try: | ||||||
|             while True: |             while True: | ||||||
|                 # 1. 从队列取出帧(阻塞直到有帧可用) |                 # 取出帧并立即发送下一帧许可 | ||||||
|                 frame_data = await self.frame_queue.get() |                 frame_data = await self.frame_queue.get() | ||||||
|  |                 await self.send_frame_permit() | ||||||
|                 # -------------------------- 核心修改: 取出帧后立即发送下一帧许可 -------------------------- |  | ||||||
|                 await self.send_frame_permit()  # 取帧即通知客户端发下一帧、无需等处理完成 |  | ||||||
|                 # ----------------------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
|                 try: |                 try: | ||||||
|                     # 2. 处理取出的帧(即使处理慢、客户端也已收到许可、可提前准备下一帧) |  | ||||||
|                     await self.process_frame(frame_data) |                     await self.process_frame(frame_data) | ||||||
|                 finally: |                 finally: | ||||||
|                     # 3. 标记帧任务完成(无论处理成功/失败、都需清理队列) |  | ||||||
|                     self.frame_queue.task_done() |                     self.frame_queue.task_done() | ||||||
|  |  | ||||||
|         except asyncio.CancelledError: |         except asyncio.CancelledError: | ||||||
| @ -93,8 +84,8 @@ class ClientConnection: | |||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") | ||||||
|  |  | ||||||
|     async def process_frame(self, frame_data: bytes) -> None: |     async def process_frame(self, frame_data: bytes) -> None: | ||||||
|         """处理单帧图像数据""" |         """处理单帧图像数据(核心修复:按3个返回值解包)""" | ||||||
|         # 二进制数据转OpenCV图像 |         # 二进制转OpenCV图像 | ||||||
|         nparr = np.frombuffer(frame_data, np.uint8) |         nparr = np.frombuffer(frame_data, np.uint8) | ||||||
|         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) |         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | ||||||
|         if img is None: |         if img is None: | ||||||
| @ -102,52 +93,41 @@ class ClientConnection: | |||||||
|             return |             return | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             # -------------------------- 提交检测任务并等待结果 -------------------------- |             # -------------------------- 修复核心:匹配detect返回的3个值 -------------------------- | ||||||
|             # 1. 提交检测任务获取Future对象(非阻塞) |             # 假设detect返回 (是否违规, 结果数据, 检测器类型) | ||||||
|             detection_future = detect(img) |             has_violation, data, detector_type = await asyncio.to_thread( | ||||||
|             # 2. 用asyncio.to_thread等待Future结果(避免阻塞asyncio事件循环),设置超时 |                 detect,  # 调用检测函数 | ||||||
|             try: |                 img      # 传入图像参数 | ||||||
|                 # 解包4元素结果:(是否违规, 结果数据, 检测器类型, 任务ID) |             ) | ||||||
|                 has_violation, data, detector_type, task_id = await asyncio.to_thread( |             # ------------------------------------------------------------------------------------- | ||||||
|                     detection_future.result,  # 调用Future的result()获取实际结果 |  | ||||||
|                     timeout=TIMEOUT  # 超时控制(与all.py配置一致) |  | ||||||
|                 ) |  | ||||||
|             except TimeoutError: |  | ||||||
|                 # 处理检测超时场景 |  | ||||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测任务超时(超过{TIMEOUT}秒)") |  | ||||||
|                 has_violation = False |  | ||||||
|                 data = f"检测超时(超过{TIMEOUT}秒)" |  | ||||||
|                 detector_type = "timeout" |  | ||||||
|                 task_id = -1  # 超时任务ID标记为-1 |  | ||||||
|             # ----------------------------------------------------------------------------------------- |  | ||||||
|  |  | ||||||
|             # 打印检测结果 |             # 打印检测结果(移除task_id相关内容) | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - " |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - " | ||||||
|                   f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}, 任务ID: {task_id}") |                   f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}") | ||||||
|  |  | ||||||
|             # 处理违规逻辑 |             # 处理违规逻辑 | ||||||
|             if has_violation: |             if has_violation: | ||||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - " |                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - " | ||||||
|                       f"类型: {detector_type}, 详情: {data}") |                       f"类型: {detector_type}, 详情: {data}") | ||||||
|  |  | ||||||
|                 # 调用违规次数加一方法 |                 # 违规次数+1 | ||||||
|                 try: |                 try: | ||||||
|                     await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip) |                     await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip) | ||||||
|                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1") |                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1") | ||||||
|                 except Exception as e: |                 except Exception as e: | ||||||
|                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") |                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") | ||||||
|  |  | ||||||
|                 # 发送「危险通知」 |                 # 发送危险通知 | ||||||
|                 danger_msg = { |                 danger_msg = { | ||||||
|                     "type": "danger", |                     "type": "danger", | ||||||
|                     "timestamp": get_current_time_str(), |                     "timestamp": get_current_time_str(), | ||||||
|                     "client_ip": self.client_ip |                     "client_ip": self.client_ip, | ||||||
|  |                     "detail": data | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 # TODO 数据存储到数据库 |  | ||||||
|                 await self.websocket.send_json(danger_msg) |                 await self.websocket.send_json(danger_msg) | ||||||
|             else: |             else: | ||||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规") |                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规") | ||||||
|  |  | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}") |             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}") | ||||||
|  |  | ||||||
| @ -157,7 +137,7 @@ connected_clients: Dict[str, ClientConnection] = {} | |||||||
| heartbeat_task: Optional[asyncio.Task] = None | heartbeat_task: Optional[asyncio.Task] = None | ||||||
|  |  | ||||||
|  |  | ||||||
| # 心跳检查(定时清理超时客户端 + 调用离线状态更新方法) | # 心跳检查任务 | ||||||
| async def heartbeat_checker(): | async def heartbeat_checker(): | ||||||
|     while True: |     while True: | ||||||
|         current_time = get_current_time_str() |         current_time = get_current_time_str() | ||||||
| @ -172,7 +152,7 @@ async def heartbeat_checker(): | |||||||
|                         conn.consumer_task.cancel() |                         conn.consumer_task.cancel() | ||||||
|                     await conn.websocket.close(code=1008, reason="心跳超时") |                     await conn.websocket.close(code=1008, reason="心跳超时") | ||||||
|  |  | ||||||
|                     # 超时设为离线并记录 |                     # 标记离线 | ||||||
|                     try: |                     try: | ||||||
|                         await asyncio.to_thread(update_online_status_by_ip, ip, 0) |                         await asyncio.to_thread(update_online_status_by_ip, ip, 0) | ||||||
|                         action_data = DeviceActionCreate(client_ip=ip, action=0) |                         action_data = DeviceActionCreate(client_ip=ip, action=0) | ||||||
| @ -247,7 +227,6 @@ ws_router = APIRouter() | |||||||
|  |  | ||||||
| @ws_router.websocket(WS_ENDPOINT) | @ws_router.websocket(WS_ENDPOINT) | ||||||
| async def websocket_endpoint(websocket: WebSocket): | async def websocket_endpoint(websocket: WebSocket): | ||||||
|     # 加载模型(首次连接时自动加载,线程安全) |  | ||||||
|     load_model() |     load_model() | ||||||
|     await websocket.accept() |     await websocket.accept() | ||||||
|     client_ip = websocket.client.host if websocket.client else "unknown_ip" |     client_ip = websocket.client.host if websocket.client else "unknown_ip" | ||||||
| @ -257,7 +236,7 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|     is_online_updated = False |     is_online_updated = False | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         # 处理重复连接(关闭同一IP的旧连接) |         # 处理重复连接 | ||||||
|         if client_ip in connected_clients: |         if client_ip in connected_clients: | ||||||
|             old_conn = connected_clients[client_ip] |             old_conn = connected_clients[client_ip] | ||||||
|             if old_conn.consumer_task and not old_conn.consumer_task.done(): |             if old_conn.consumer_task and not old_conn.consumer_task.done(): | ||||||
| @ -270,10 +249,9 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|         new_conn = ClientConnection(websocket, client_ip) |         new_conn = ClientConnection(websocket, client_ip) | ||||||
|         connected_clients[client_ip] = new_conn |         connected_clients[client_ip] = new_conn | ||||||
|         new_conn.start_consumer() |         new_conn.start_consumer() | ||||||
|         # 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧 |  | ||||||
|         await new_conn.send_frame_permit() |         await new_conn.send_frame_permit() | ||||||
|  |  | ||||||
|         # 标记上线并记录 |         # 标记上线 | ||||||
|         try: |         try: | ||||||
|             await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) |             await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) | ||||||
|             action_data = DeviceActionCreate(client_ip=client_ip, action=1) |             action_data = DeviceActionCreate(client_ip=client_ip, action=1) | ||||||
| @ -285,7 +263,7 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|  |  | ||||||
|         print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}") |         print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}") | ||||||
|  |  | ||||||
|         # 消息循环(接收客户端文本/二进制消息) |         # 消息循环 | ||||||
|         while True: |         while True: | ||||||
|             data = await websocket.receive() |             data = await websocket.receive() | ||||||
|             if "text" in data: |             if "text" in data: | ||||||
| @ -298,13 +276,12 @@ async def websocket_endpoint(websocket: WebSocket): | |||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") |         print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") | ||||||
|     finally: |     finally: | ||||||
|         # 清理资源并标记离线 |         # 清理资源 | ||||||
|         if client_ip in connected_clients: |         if client_ip in connected_clients: | ||||||
|             conn = connected_clients[client_ip] |             conn = connected_clients[client_ip] | ||||||
|             if conn.consumer_task and not conn.consumer_task.done(): |             if conn.consumer_task and not conn.consumer_task.done(): | ||||||
|                 conn.consumer_task.cancel() |                 conn.consumer_task.cancel() | ||||||
|  |  | ||||||
|             # 主动/异常断开时标记离线(仅当上线状态更新成功时) |  | ||||||
|             if is_online_updated: |             if is_online_updated: | ||||||
|                 try: |                 try: | ||||||
|                     await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) |                     await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user