diff --git a/core/all.py b/core/all.py
index 4dd1ef9..b070a2a 100644
--- a/core/all.py
+++ b/core/all.py
@@ -1,139 +1,70 @@
+import cv2
 from core.ocr import load_model as ocrLoadModel, detect as ocrDetect
 from core.face import load_model as faceLoadModel, detect as faceDetect
 from core.yolo import load_model as yoloLoadModel, detect as yoloDetect
-from concurrent.futures import ThreadPoolExecutor, Future
-import threading
-import cv2
-import numpy as np
-
-# -------------------------- 核心配置参数 --------------------------
-MAX_WORKERS = 6  # 线程池最大线程数
-DETECTION_ORDER = ["yolo", "face", "ocr"]  # 检测执行顺序
-TIMEOUT = 30  # 检测超时时间(秒) 【确保此常量可被外部导入】
-
-# -------------------------- 全局状态管理 --------------------------
-_executor = None  # 线程池实例
-_model_loaded = False  # 模型加载状态标记
-_model_lock = threading.Lock()  # 模型加载线程锁
-_executor_lock = threading.Lock()  # 线程池初始化锁
-_task_counter = 0  # 任务计数器
-_task_counter_lock = threading.Lock()  # 任务计数锁
+# 导入保存路径函数(根据实际文件位置调整导入路径)
+from core.establish import get_image_save_path
+# 模型加载状态标记(避免重复加载)
 
-# -------------------------- 工具函数 --------------------------
-def _get_next_task_id():
-    """获取唯一任务ID、用于日志追踪"""
-    global _task_counter
-    with _task_counter_lock:
-        _task_counter += 1
-        return _task_counter
+_model_loaded = False
 
-# -------------------------- 模型加载 --------------------------
 def load_model():
-    """加载所有检测模型并初始化线程池(仅执行一次)"""
+    """加载所有检测模型(仅首次调用时执行)"""
     global _model_loaded
-    if not _model_loaded:
-        with _model_lock:
-            if not _model_loaded:
-                print("=== 开始加载检测模型 ===")
+    if _model_loaded:
+        print("模型已加载,无需重复执行")
+        return
 
-                # 按顺序加载模型
-                print("加载YOLO模型...")
-                yoloLoadModel()
+    # 依次加载OCR、人脸、YOLO模型
+    ocrLoadModel()
+    faceLoadModel()
+    yoloLoadModel()
 
-                print("加载人脸检测模型...")
-                faceLoadModel()
-
-                print("加载OCR模型...")
-                ocrLoadModel()
-
-                _model_loaded = True
-                print("=== 所有模型加载完成 ===")
-
-                # 初始化线程池
-                _init_thread_pool()
+    _model_loaded = True
+    print("所有检测模型加载完成")
 
-# -------------------------- 线程池管理 --------------------------
-def _init_thread_pool():
-    """初始化线程池(仅内部调用)"""
-    global _executor
-    with _executor_lock:
-        if _executor is None:
-            _executor = ThreadPoolExecutor(
-                max_workers=MAX_WORKERS,
-                thread_name_prefix="DetectionThread"
-            )
-            print(f"=== 线程池初始化完成、最大线程数: {MAX_WORKERS} ===")
-
-
-def shutdown():
-    """关闭线程池、释放资源"""
-    global _executor
-    with _executor_lock:
-        if _executor is not None:
-            _executor.shutdown(wait=True)
-            _executor = None
-            print("=== 线程池已安全关闭 ===")
-
-
-# -------------------------- 检测逻辑实现 --------------------------
-def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
-    """在子线程中执行检测逻辑(返回4元素tuple:检测是否成功、结果数据、检测器类型、任务ID)"""
-    thread_name = threading.current_thread().name
-    print(f"任务[{task_id}] 开始执行、线程: {thread_name}")
-
-    try:
-        # 按照配置顺序执行检测
-        for detector in DETECTION_ORDER:
-            if detector == "yolo":
-                success, result = yoloDetect(frame)
-            elif detector == "face":
-                success, result = faceDetect(frame)
-            elif detector == "ocr":
-                success, result = ocrDetect(frame)
-            else:
-                success, result = False, None
-
-            print(f"任务[{task_id}] {detector}检测状态: {'成功' if success else '未检测到内容'}")
-            if success:
-                print(f"任务[{task_id}] 完成检测、使用检测器: {detector}")
-                return (success, result, detector, task_id)  # 4元素tuple
-
-        # 所有检测器均未检测到结果
-        print(f"任务[{task_id}] 所有检测器均未检测到有效内容")
-        return (False, "未检测到任何有效内容", "none", task_id)  # 4元素tuple
-
-    except Exception as e:
-        print(f"任务[{task_id}] 检测过程发生错误: {str(e)}")
-        return (False, f"检测错误: {str(e)}", "error", task_id)  # 4元素tuple
-
-
-# -------------------------- 外部调用接口 --------------------------
-def detect(frame: np.ndarray) -> Future:
+def detect(frame):
     """
-    提交检测任务到线程池(返回Future对象,需调用result()获取4元素结果)
-
-    参数:
-        frame: 待检测图像(ndarray格式、cv2.imdecode生成)
-
-    返回:
-        Future对象、result()返回tuple: (success, data, detector_type, task_id)
-        success: 布尔值,表示是否检测到有效内容
-        data: 检测结果数据(成功时为具体结果,失败时为错误信息)
-        detector_type: 使用的检测器类型("yolo"/"face"/"ocr"/"none"/"error")
-        task_id: 任务唯一标识
+    执行模型检测,检测到违规时按指定格式保存图片
+    参数:
+        frame: 待检测的图像帧(OpenCV格式,numpy.ndarray类型)
+    返回:
+        (检测结果布尔值, 检测详情, 检测模型类型)
     """
-    # 确保模型已加载
-    if not _model_loaded:
-        print("警告: 模型尚未加载、将自动加载")
-        load_model()
+    # 1. YOLO检测(优先级1)
+    yolo_flag, yolo_result = yoloDetect(frame)
+    print(f"YOLO检测结果:{yolo_result}")
+    if yolo_flag:
+        # 直接调用路径生成函数,无需传入原始图片名
+        save_path = get_image_save_path(model_type="yolo")
+        if save_path:
+            cv2.imwrite(save_path, frame)
+            print(f"✅ YOLO违规图片已保存:{save_path}")
+        return (True, yolo_result, "yolo")
 
-    # 生成任务ID
-    task_id = _get_next_task_id()
+    # 2. 人脸检测(优先级2)
+    face_flag, face_result = faceDetect(frame)
+    print(f"人脸检测结果:{face_result}")
+    if face_flag:
+        save_path = get_image_save_path(model_type="face")
+        if save_path:
+            cv2.imwrite(save_path, frame)
+            print(f"✅ 人脸违规图片已保存:{save_path}")
+        return (True, face_result, "face")
 
-    # 提交任务到线程池(返回Future)
-    future = _executor.submit(_detect_in_thread, frame, task_id)
-    print(f"任务[{task_id}]: 已提交到线程池")
-    return future
+    # 3. OCR检测(优先级3)
+    ocr_flag, ocr_result = ocrDetect(frame)
+    print(f"OCR检测结果:{ocr_result}")
+    if ocr_flag:
+        save_path = get_image_save_path(model_type="ocr")
+        if save_path:
+            cv2.imwrite(save_path, frame)
+            print(f"✅ OCR违规图片已保存:{save_path}")
+        return (True, ocr_result, "ocr")
+
+    # 4. 无违规内容(不保存图片)
+    print(f"❌ 未检测到任何违规内容,不保存图片")
+    return (False, "未检测到任何内容", "none")
\ No newline at end of file
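
The rewritten core/all.py drops the thread-pool/Future interface in favour of a plain synchronous call that runs YOLO, face and OCR detection in sequence, saves the offending frame via get_image_save_path(), and returns a 3-tuple. A minimal caller sketch under that assumption (the image path below is illustrative only):

import cv2
from core.all import load_model, detect

load_model()                          # loads the OCR, face and YOLO models once
frame = cv2.imread("sample.jpg")      # hypothetical test image
if frame is not None:
    has_violation, detail, detector_type = detect(frame)
    print(has_violation, detector_type, detail)
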
diff --git a/core/establish.py b/core/establish.py
new file mode 100644
index 0000000..4b8027c
--- /dev/null
+++ b/core/establish.py
@@ -0,0 +1,111 @@
+import os
+import datetime
+from pathlib import Path
+
+
+# 配置IP文件路径(统一使用绝对路径)
+IP_FILE_PATH = Path(r"D:\ccc\IP.txt")
+
+
+def create_directory_structure():
+    """创建项目所需的目录结构"""
+    try:
+        # 1. 创建根目录下的resource文件夹
+        resource_dir = Path("resource")
+        resource_dir.mkdir(exist_ok=True)
+        print(f"确保resource目录存在: {resource_dir.absolute()}")
+
+        # 2. 在resource下创建dect文件夹
+        dect_dir = resource_dir / "dect"
+        dect_dir.mkdir(exist_ok=True)
+        print(f"确保dect目录存在: {dect_dir.absolute()}")
+
+        # 3. 在dect下创建三个模型文件夹
+        model_dirs = ["ocr", "face", "yolo"]
+        for model in model_dirs:
+            model_dir = dect_dir / model
+            model_dir.mkdir(exist_ok=True)
+            print(f"确保{model}模型目录存在: {model_dir.absolute()}")
+
+        # 4. 读取ip.txt文件获取IP地址
+        try:
+            with open(IP_FILE_PATH, "r") as f:
+                ip_addresses = [line.strip() for line in f if line.strip()]
+
+            if not ip_addresses:
+                print("警告: ip.txt文件中未找到有效的IP地址")
+                return
+
+            print(f"从ip.txt中读取到的IP地址: {ip_addresses}")
+
+            # 5. 获取当前日期
+            now = datetime.datetime.now()
+            current_year = str(now.year)
+            current_month = str(now.month)
+
+            # 6. 为每个IP在每个模型文件夹下创建年->月的目录结构
+            for ip in ip_addresses:
+                # 处理IP地址中的特殊字符(如果有)
+                safe_ip = ip.replace(".", "_")
+
+                for model in model_dirs:
+                    # 构建路径: resource/dect/{model}/{ip}/{year}/{month}
+                    ip_dir = dect_dir / model / safe_ip
+                    year_dir = ip_dir / current_year
+                    month_dir = year_dir / current_month
+
+                    # 创建目录(如果不存在)
+                    month_dir.mkdir(parents=True, exist_ok=True)
+                    print(f"创建/确保目录存在: {month_dir.absolute()}")
+
+        except FileNotFoundError:
+            print(f"错误: 未找到ip.txt文件,请确保该文件存在于 {IP_FILE_PATH}")
+        except Exception as e:
+            print(f"处理IP和日期目录时发生错误: {str(e)}")
+
+    except Exception as e:
+        print(f"创建目录结构时发生错误: {str(e)}")
+
+
+def get_image_save_path(model_type: str) -> str:
+    """
+    获取图片保存的完整路径(不依赖原始图片名称)
+
+    参数:
+        model_type: 模型类型,应为"ocr"、"face"或"yolo"
+
+    返回:
+        完整的图片保存路径
+    """
+    try:
+        # 读取IP地址(假设只有一个IP或使用第一个IP)
+        with open(IP_FILE_PATH, "r") as f:
+            ip_addresses = [line.strip() for line in f if line.strip()]
+
+        if not ip_addresses:
+            raise ValueError("ip.txt文件中未找到有效的IP地址")
+
+        ip = ip_addresses[0]
+        safe_ip = ip.replace(".", "_")
+
+        # 获取当前日期和时间(精确到毫秒,确保文件名唯一)
+        now = datetime.datetime.now()
+        current_year = str(now.year)
+        current_month = str(now.month)
+        current_day = str(now.day)
+        # 生成时间戳字符串(格式:年月日时分秒毫秒)
+        timestamp = now.strftime("%Y%m%d%H%M%S%f")[:-3]  # 去除最后三位,保留到毫秒
+
+        # 构建路径: resource/dect/{model}/{ip}/{year}/{month}/{day}
+        day_dir = Path("resource") / "dect" / model_type / safe_ip / current_year / current_month / current_day
+        day_dir.mkdir(parents=True, exist_ok=True)
+
+        # 构建图片文件名(使用时间戳确保唯一性)
+        image_filename = f"resource_dect_{model_type}_{safe_ip}_{current_year}_{current_month}_{current_day}_{timestamp}.jpg"
+        image_path = day_dir / image_filename
+
+        return str(image_path)
+
+    except Exception as e:
+        print(f"获取图片保存路径时发生错误: {str(e)}")
+        return ""
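
get_image_save_path() derives the save location from the first IP listed in D:\ccc\IP.txt plus the current date, so callers only pass the model type. A rough usage sketch; the IP and timestamp in the comment are made-up examples of the resulting shape, not values from the diff:

from core.establish import create_directory_structure, get_image_save_path

create_directory_structure()                   # resource/dect/{ocr,face,yolo}/{ip}/{year}/{month}
path = get_image_save_path(model_type="yolo")
# e.g. resource/dect/yolo/192_168_1_10/2025/3/7/resource_dect_yolo_192_168_1_10_2025_3_7_20250307101530123.jpg
print(path)
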
diff --git a/core/face.py b/core/face.py
index 4306d4e..b8e972a 100644
--- a/core/face.py
+++ b/core/face.py
@@ -6,203 +6,217 @@ import time
 import threading
 from PIL import Image
 from insightface.app import FaceAnalysis
-# 导入获取人脸信息的服务
+# 假设service.face_service中get_all_face_name_with_eigenvalue可获取人脸数据
 from service.face_service import get_all_face_name_with_eigenvalue
-# 用于检查GPU状态
+# GPU状态检查支持
 try:
     import pynvml
     pynvml.nvmlInit()
     _nvml_available = True
 except ImportError:
-    print("警告: pynvml库未安装、无法检测GPU状态、将默认使用0号GPU")
+    print("警告: pynvml库未安装,无法检测GPU状态,默认尝试使用GPU")
    _nvml_available = False
 
-# 全局变量
+# 全局人脸引擎与特征库
 _face_app = None
-_known_faces_embeddings = {}  # 存储姓名到特征值的映射
-_known_faces_names = []  # 存储所有已知姓名
-_using_gpu = False  # 标记是否使用GPU
-_used_gpu_id = -1  # 记录当前使用的GPU ID
+_known_faces_embeddings = {}  # 姓名 -> 归一化特征值的映射
+_known_faces_names = []  # 已知人脸姓名列表
+
+# GPU使用状态标记
+_using_gpu = False  # 是否使用GPU
+_used_gpu_id = -1  # 使用的GPU ID(-1表示CPU)
 
 # 资源管理变量
-_ref_count = 0
-_last_used_time = 0
-_lock = threading.Lock()
-_release_timeout = 8  # 5秒无使用则释放
-_is_releasing = False  # 标记是否正在释放
+_ref_count = 0  # 引擎引用计数(记录当前使用次数)
+_last_used_time = 0  # 最后一次使用引擎的时间
+_lock = threading.Lock()  # 线程安全锁
+_release_timeout = 8  # 闲置超时时间(秒)
+_is_releasing = False  # 资源释放中标记
+_monitor_thread_running = False  # 监控线程运行标记
 
-# 调试用计数器
+# 调试计数器
 _debug_counter = {
-    "created": 0,
-    "released": 0,
-    "detected": 0
+    "engine_created": 0,    # 引擎创建次数
+    "engine_released": 0,   # 引擎释放次数
+    "detection_calls": 0    # 检测函数调用次数
 }
 
-def check_gpu_availability(gpu_id, threshold=0.7):
-    """检查指定GPU是否可用(内存使用率低于阈值)"""
+def check_gpu_availability(gpu_id, memory_threshold=0.7):
+    """检查指定GPU的内存使用率是否低于阈值(判定为"可用")"""
     if not _nvml_available:
-        return True  # 无法检测时默认认为可用
-
+        return True
     try:
        handle =
pynvml.nvmlDeviceGetHandleByIndex(gpu_id) mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - usage = mem_info.used / mem_info.total - # 内存使用率低于阈值则认为可用 - return usage < threshold + memory_usage = mem_info.used / mem_info.total + return memory_usage < memory_threshold except Exception as e: - print(f"检查GPU {gpu_id} 状态时出错: {e}") + print(f"检查GPU {gpu_id} 状态失败: {e}") return False def select_best_gpu(preferred_gpus=[0, 1]): - """选择最佳可用GPU、严格按照首选列表顺序检查、优先使用0号GPU""" - # 首先检查首选GPU列表 + """按优先级选择可用GPU,优先0号;均不可用则返回-1(CPU)""" for gpu_id in preferred_gpus: try: - # 检查GPU是否存在 + # 验证GPU是否存在 if _nvml_available: pynvml.nvmlDeviceGetHandleByIndex(gpu_id) - - # 检查GPU是否可用 + # 验证GPU内存是否充足 if check_gpu_availability(gpu_id): - print(f"GPU {gpu_id} 可用、将使用该GPU") + print(f"GPU {gpu_id} 可用,将使用该GPU") return gpu_id else: if gpu_id == 0: - print(f"GPU 0 内存使用率过高(繁忙)、尝试切换到其他GPU") + print("GPU 0 内存使用率过高,尝试其他GPU") except Exception as e: - print(f"GPU {gpu_id} 不存在或无法访问: {e}") - continue - - # 如果所有首选GPU都不可用、返回-1表示使用CPU - print("所有指定的GPU都不可用、将使用CPU进行计算") + print(f"GPU {gpu_id} 不可用或访问失败: {e}") + print("所有指定GPU均不可用,将使用CPU计算") return -1 -def _release_engine(): - """释放人脸识别引擎资源""" +def _release_engine_resources(): + """释放人脸引擎的所有资源(模型、特征库、GPU缓存等)""" global _face_app, _is_releasing, _known_faces_embeddings, _known_faces_names if not _face_app or _is_releasing: return try: _is_releasing = True - # 释放InsightFace资源 - if hasattr(_face_app, 'model'): - # 清除模型资源 - _face_app.model = None - _face_app = None + print("开始释放人脸引擎资源...") - # 清空人脸数据 + # 释放InsightFace模型资源 + if hasattr(_face_app, "model"): + _face_app.model = None # 显式置空模型引用 + _face_app = None # 释放引擎实例 + + # 清空人脸特征库 _known_faces_embeddings.clear() _known_faces_names.clear() - _debug_counter["released"] += 1 - print(f"Face recognition engine released. 
Stats: {_debug_counter}") + _debug_counter["engine_released"] += 1 + print(f"人脸引擎已释放,调试统计: {_debug_counter}") - # 清理GPU缓存 + # 强制垃圾回收 gc.collect() + + # 清理各深度学习框架的GPU缓存 + # Torch 缓存清理 try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() + print("Torch GPU缓存已清理") except ImportError: pass + + # TensorFlow 缓存清理 try: import tensorflow as tf tf.keras.backend.clear_session() + print("TensorFlow会话已清理") except ImportError: pass + + # MXNet 缓存清理(InsightFace底层常用MXNet) + try: + import mxnet as mx + mx.nd.waitall() # 等待所有计算完成并释放资源 + print("MXNet资源已等待释放") + except ImportError: + pass + + except Exception as e: + print(f"释放资源过程中出错: {e}") finally: _is_releasing = False -def _monitor_thread(): - """监控线程、检查并释放超时未使用的资源""" - global _ref_count, _last_used_time, _face_app - while True: - time.sleep(5) # 每5秒检查一次 +def _resource_monitor_thread(): + """后台监控线程:检测引擎闲置超时,触发资源释放""" + global _ref_count, _last_used_time, _face_app, _monitor_thread_running + _monitor_thread_running = True + while _monitor_thread_running: + time.sleep(2) # 缩短检查间隔,加快闲置检测响应 with _lock: - # 只有当引擎存在、没有引用且超时、才释放 + # 当“引擎存在 + 无引用 + 未在释放中”时,检查闲置时间 if _face_app and _ref_count == 0 and not _is_releasing: - elapsed = time.time() - _last_used_time - if elapsed > _release_timeout: - print(f"Idle timeout ({elapsed:.1f}s > {_release_timeout}s), releasing face engine") - _release_engine() + idle_time = time.time() - _last_used_time + if idle_time > _release_timeout: + print(f"引擎闲置超时({idle_time:.1f}s > {_release_timeout}s),释放资源") + _release_engine_resources() def load_model(prefer_gpu=True, preferred_gpus=[0, 1]): - """加载人脸识别模型及已知人脸特征库、默认优先使用0号GPU""" + """加载人脸识别引擎及已知人脸特征库(默认优先用0号GPU)""" global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id - # 确保监控线程只启动一次 - if not any(t.name == "FaceMonitor" for t in threading.enumerate()): - threading.Thread(target=_monitor_thread, daemon=True, name="FaceMonitor").start() - print("Face monitor thread started") + # 启动后台监控线程(确保仅启动一次) + if not _monitor_thread_running: + threading.Thread( + target=_resource_monitor_thread, + daemon=True, + name="FaceEngineMonitor" + ).start() + print("人脸引擎监控线程已启动") - # 如果正在释放中、等待释放完成 + # 若正在释放资源,等待释放完成 while _is_releasing: time.sleep(0.1) - # 如果已经初始化、直接返回 + # 若引擎已初始化,直接返回 if _face_app: return True - # 初始化InsightFace模型 + # 初始化InsightFace引擎 try: - # 初始化InsightFace print("正在初始化InsightFace人脸识别引擎...") - _face_app = FaceAnalysis(name='buffalo_l', root='~/.insightface') + _face_app = FaceAnalysis(name="buffalo_l", root=os.path.expanduser("~/.insightface")) - # 选择合适的GPU、默认优先使用0号 + # 选择GPU(优先用0号) ctx_id = 0 if prefer_gpu: ctx_id = select_best_gpu(preferred_gpus) _using_gpu = ctx_id != -1 _used_gpu_id = ctx_id if _using_gpu else -1 - if _using_gpu: - print(f"成功初始化、使用GPU {ctx_id} 进行计算") - else: - print("成功初始化、使用CPU进行计算") + if _using_gpu: + print(f"引擎初始化成功,将使用GPU {ctx_id} 计算") + else: + print("引擎初始化成功,将使用CPU计算") - # 准备模型 + # 准备模型(加载到指定设备) _face_app.prepare(ctx_id=ctx_id, det_size=(640, 640)) - print("InsightFace人脸识别引擎初始化成功。") - _debug_counter["created"] += 1 - print(f"Face engine initialized. 
Stats: {_debug_counter}") + print("InsightFace引擎初始化完成") + _debug_counter["engine_created"] += 1 + print(f"引擎调试统计: {_debug_counter}") except Exception as e: - print(f"初始化失败: {e}") + print(f"引擎初始化失败: {e}") return False - # 从服务获取所有人脸姓名和特征值 + # 从服务加载已知人脸的姓名和特征值 try: face_data = get_all_face_name_with_eigenvalue() - - # 处理获取到的人脸数据 for person_name, eigenvalue_data in face_data.items(): - # 处理特征值数据 - 兼容数组和字符串两种格式 + # 兼容“numpy数组”和“字符串”格式的特征值 if isinstance(eigenvalue_data, np.ndarray): - # 如果已经是numpy数组、直接使用 eigenvalue = eigenvalue_data.astype(np.float32) elif isinstance(eigenvalue_data, str): - # 清理字符串: 移除方括号、换行符和多余空格 - cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip() - # 按空格或逗号分割(处理可能的不同分隔符) - values = [v for v in cleaned.split() if v] - # 转换为数组 + # 清理字符串中的括号、换行等干扰符 + cleaned = eigenvalue_data.replace("[", "").replace("]", "").replace("\n", "").strip() + # 分割并转换为浮点数数组 + values = [v for v in cleaned.split() if v] # 兼容空格/逗号分隔 eigenvalue = np.array(list(map(float, values)), dtype=np.float32) else: - # 不支持的类型 - print(f"Unsupported eigenvalue type for {person_name}") + print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}") continue - # 归一化处理 + # 特征值归一化(保证后续相似度计算的一致性) norm = np.linalg.norm(eigenvalue) if norm != 0: eigenvalue = eigenvalue / norm @@ -210,100 +224,103 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]): _known_faces_embeddings[person_name] = eigenvalue _known_faces_names.append(person_name) + print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库") + except Exception as e: - print(f"Error loading face data from service: {e}") + print(f"加载人脸特征库失败: {e}") - return True if _face_app else False + return _face_app is not None -def detect(frame, threshold=0.4): - """检测并识别人脸、返回结果元组(是否匹配到已知人脸, 结果字符串)""" - global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id - global _ref_count, _last_used_time +def detect(frame, similarity_threshold=0.4): + """ + 检测并识别人脸 + 返回:(是否匹配到已知人脸, 结果描述字符串) + """ + global _face_app, _known_faces_embeddings, _known_faces_names, _ref_count, _last_used_time - # 验证前置条件 + # 校验输入帧有效性 if frame is None or frame.size == 0: - return (False, "无效帧数据") + return (False, "无效的输入帧数据") - # 增加引用计数并获取引擎实例 + # 加锁并更新引用计数、最后使用时间 engine = None with _lock: _ref_count += 1 _last_used_time = time.time() - _debug_counter["detected"] += 1 + _debug_counter["detection_calls"] += 1 - # 初始化引擎(如果未初始化且不在释放中) + # 若引擎未初始化且未在释放中,尝试初始化 if not _face_app and not _is_releasing: if not load_model(prefer_gpu=True): - _ref_count -= 1 # 恢复引用计数 - return (False, "引擎初始化失败") + # 初始化失败,恢复引用计数 + with _lock: + _ref_count = max(0, _ref_count - 1) + return (False, "人脸引擎初始化失败") - # 获取当前引擎引用 - engine = _face_app + engine = _face_app # 获取引擎引用 - # 检查引擎是否可用 - if not engine or not _known_faces_names: + # 校验引擎可用性 + if not engine or len(_known_faces_names) == 0: with _lock: _ref_count = max(0, _ref_count - 1) - return (False, "人脸识别引擎不可用或未初始化") + return (False, "人脸引擎不可用或特征库为空") try: - # 如果使用GPU、确保输入帧在处理前是连续的数组 - if _using_gpu and not frame.flags.contiguous: + # GPU计算时,确保帧数据是连续内存(避免CUDA错误) + if _using_gpu and engine is not None and not frame.flags.contiguous: frame = np.ascontiguousarray(frame) - faces = _face_app.get(frame) + # 执行人脸检测与特征提取 + faces = engine.get(frame) except Exception as e: - print(f"Face detect error: {e}") - # 检测到错误时尝试重新选择GPU并重新初始化 - print("尝试重新选择GPU并重新初始化...") + print(f"人脸检测过程出错: {e}") + # 出错时尝试重新初始化引擎(可能是GPU状态变化导致) + print("尝试重新初始化人脸引擎...") with _lock: _ref_count = max(0, _ref_count - 1) - load_model(prefer_gpu=True) # 重新初始化时保持默认GPU优先级 + 
load_model(prefer_gpu=True) return (False, f"检测错误: {str(e)}") result_parts = [] - has_matched = False # 标记是否有匹配到的已知人脸 + has_matched_known_face = False # 是否有任意人脸匹配到已知库 for face in faces: - # 特征归一化 - embedding = face.embedding.astype(np.float32) - norm = np.linalg.norm(embedding) + # 归一化当前检测到的人脸特征 + face_embedding = face.embedding.astype(np.float32) + norm = np.linalg.norm(face_embedding) if norm == 0: continue - embedding = embedding / norm + face_embedding = face_embedding / norm - # 对比已知人脸 - max_sim, best_name = -1.0, "Unknown" + # 与已知人脸特征逐一比对 + max_similarity, best_match_name = -1.0, "Unknown" for name in _known_faces_names: known_emb = _known_faces_embeddings[name] - sim = np.dot(embedding, known_emb) - if sim > max_sim: - max_sim = sim - best_name = name + similarity = np.dot(face_embedding, known_emb) # 余弦相似度 + if similarity > max_similarity: + max_similarity = similarity + best_match_name = name - # 判断匹配结果 - is_match = max_sim >= threshold - if is_match: - has_matched = True # 只要有一个匹配成功、就标记为True + # 判断是否匹配成功 + is_matched = max_similarity >= similarity_threshold + if is_matched: + has_matched_known_face = True - bbox = face.bbox + # 记录该人脸的检测结果 + bbox = face.bbox # 人脸边界框 result_parts.append( - f"{'匹配' if is_match else '不匹配'}: {best_name} (相似度: {max_sim:.2f}, 边界框: {bbox})" + f"{'匹配' if is_matched else '未匹配'}: {best_match_name} " + f"(相似度: {max_similarity:.2f}, 边界框: {bbox.astype(int).tolist()})" ) - # 构建结果字符串 - if not result_parts: - result_str = "未检测到人脸" - else: - result_str = "; ".join(result_parts) + # 构建最终结果字符串 + result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts) - # 减少引用计数、确保线程安全 + # 释放引用计数(线程安全) with _lock: _ref_count = max(0, _ref_count - 1) - # 持续使用时更新最后使用时间 - if _ref_count > 0: - _last_used_time = time.time() + # 若仍有引用,更新最后使用时间;若引用为0,也立即标记(加快闲置检测) + _last_used_time = time.time() - # 第一个返回值为: 是否匹配到已知人脸 - return (has_matched, result_str) \ No newline at end of file + return (has_matched_known_face, result_str) \ No newline at end of file diff --git a/core/ocr.py b/core/ocr.py index c4c78b0..4d46ca2 100644 --- a/core/ocr.py +++ b/core/ocr.py @@ -167,7 +167,19 @@ def detect(frame): items_to_process = [line] for item in items_to_process: - # 跳过纯数字列表(可能是坐标信息) + # 精确识别并忽略图片坐标位置信息 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + if isinstance(item, list) and len(item) == 4: # 四边形有4个顶点 + is_coordinate = True + for point in item: + # 每个顶点应该是包含2个数字的列表 + if not (isinstance(point, list) and len(point) == 2 and + all(isinstance(coord, (int, float)) for coord in point)): + is_coordinate = False + break + if is_coordinate: + continue # 是坐标信息,直接忽略 + + # 跳过纯数字列表(其他可能的坐标形式) if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item): continue diff --git a/main.py b/main.py index ec9df8a..271c78d 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ from service.sensitive_service import router as sensitive_router from service.face_service import router as face_router from service.device_service import router as device_router from ws.ws import ws_router, lifespan - +from core.establish import create_directory_structure # ------------------------------ # 初始化 FastAPI 应用、指定生命周期管理 @@ -47,6 +47,8 @@ if __name__ == "__main__": YOLO_MODEL_PATH = r"/core/models\best.pt" OCR_CONFIG_PATH = r"/core/config\config.yaml" + create_directory_structure() + # 初始化项目(默认端口设为8000、避免初始化失败时port未定义) port = int(SERVER_CONFIG.get("port", 8000)) diff --git a/ws/ws.py b/ws/ws.py index bcc86b8..dc570e0 100644 --- a/ws/ws.py +++ b/ws/ws.py @@ -3,12 +3,11 @@ import datetime import json import os from 
contextlib import asynccontextmanager -from typing import Dict, Optional, AsyncGenerator +from typing import Dict, Optional from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip from service.device_action_service import add_device_action from schema.device_action_schema import DeviceActionCreate -# 【修改1:导入detect和TIMEOUT(用于检测超时控制)】 -from core.all import detect, load_model, TIMEOUT +from core.all import detect, load_model import cv2 import numpy as np @@ -21,7 +20,7 @@ WS_ENDPOINT = "/ws" # WebSocket端点路径 FRAME_QUEUE_SIZE = 1 # 帧队列大小限制 -# 工具函数: 获取格式化时间字符串(统一时间戳格式) +# 工具函数: 获取格式化时间字符串 def get_current_time_str() -> str: return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -40,13 +39,13 @@ class ClientConnection: self.consumer_task: Optional[asyncio.Task] = None def update_heartbeat(self): - """更新心跳时间(客户端发送心跳时调用)""" + """更新心跳时间""" self.last_heartbeat = datetime.datetime.now() def is_alive(self) -> bool: - """判断客户端是否存活(心跳超时检查)""" - timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() - return timeout < HEARTBEAT_TIMEOUT + """判断客户端是否存活""" + timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds() + return timeout_seconds < HEARTBEAT_TIMEOUT def start_consumer(self): """启动帧消费任务""" @@ -54,10 +53,7 @@ class ClientConnection: return self.consumer_task async def send_frame_permit(self): - """ - 发送「帧发送许可信号」 - 通知客户端可发送下一帧图像 - """ + """发送帧发送许可信号""" try: frame_permit_msg = { "type": "frame", @@ -65,26 +61,21 @@ class ClientConnection: "client_ip": self.client_ip } await self.websocket.send_json(frame_permit_msg) - print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号(取帧后立即通知)") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号") except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}") async def consume_frames(self) -> None: - """消费队列中的帧并处理(核心调整: 取帧后立即发许可、再处理帧)""" + """消费队列中的帧并处理""" try: while True: - # 1. 从队列取出帧(阻塞直到有帧可用) + # 取出帧并立即发送下一帧许可 frame_data = await self.frame_queue.get() - - # -------------------------- 核心修改: 取出帧后立即发送下一帧许可 -------------------------- - await self.send_frame_permit() # 取帧即通知客户端发下一帧、无需等处理完成 - # ----------------------------------------------------------------------------------------- + await self.send_frame_permit() try: - # 2. 处理取出的帧(即使处理慢、客户端也已收到许可、可提前准备下一帧) await self.process_frame(frame_data) finally: - # 3. 标记帧任务完成(无论处理成功/失败、都需清理队列) self.frame_queue.task_done() except asyncio.CancelledError: @@ -93,8 +84,8 @@ class ClientConnection: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") async def process_frame(self, frame_data: bytes) -> None: - """处理单帧图像数据""" - # 二进制数据转OpenCV图像 + """处理单帧图像数据(核心修复:按3个返回值解包)""" + # 二进制转OpenCV图像 nparr = np.frombuffer(frame_data, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if img is None: @@ -102,52 +93,41 @@ class ClientConnection: return try: - # -------------------------- 提交检测任务并等待结果 -------------------------- - # 1. 提交检测任务获取Future对象(非阻塞) - detection_future = detect(img) - # 2. 
用asyncio.to_thread等待Future结果(避免阻塞asyncio事件循环),设置超时 - try: - # 解包4元素结果:(是否违规, 结果数据, 检测器类型, 任务ID) - has_violation, data, detector_type, task_id = await asyncio.to_thread( - detection_future.result, # 调用Future的result()获取实际结果 - timeout=TIMEOUT # 超时控制(与all.py配置一致) - ) - except TimeoutError: - # 处理检测超时场景 - print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测任务超时(超过{TIMEOUT}秒)") - has_violation = False - data = f"检测超时(超过{TIMEOUT}秒)" - detector_type = "timeout" - task_id = -1 # 超时任务ID标记为-1 - # ----------------------------------------------------------------------------------------- + # -------------------------- 修复核心:匹配detect返回的3个值 -------------------------- + # 假设detect返回 (是否违规, 结果数据, 检测器类型) + has_violation, data, detector_type = await asyncio.to_thread( + detect, # 调用检测函数 + img # 传入图像参数 + ) + # ------------------------------------------------------------------------------------- - # 打印检测结果 + # 打印检测结果(移除task_id相关内容) print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - " - f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}, 任务ID: {task_id}") + f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}") # 处理违规逻辑 if has_violation: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - " f"类型: {detector_type}, 详情: {data}") - # 调用违规次数加一方法 + # 违规次数+1 try: await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip) print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1") except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") - # 发送「危险通知」 + # 发送危险通知 danger_msg = { "type": "danger", "timestamp": get_current_time_str(), - "client_ip": self.client_ip + "client_ip": self.client_ip, + "detail": data } - - # TODO 数据存储到数据库 await self.websocket.send_json(danger_msg) else: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规") + except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}") @@ -157,7 +137,7 @@ connected_clients: Dict[str, ClientConnection] = {} heartbeat_task: Optional[asyncio.Task] = None -# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法) +# 心跳检查任务 async def heartbeat_checker(): while True: current_time = get_current_time_str() @@ -172,7 +152,7 @@ async def heartbeat_checker(): conn.consumer_task.cancel() await conn.websocket.close(code=1008, reason="心跳超时") - # 超时设为离线并记录 + # 标记离线 try: await asyncio.to_thread(update_online_status_by_ip, ip, 0) action_data = DeviceActionCreate(client_ip=ip, action=0) @@ -247,7 +227,6 @@ ws_router = APIRouter() @ws_router.websocket(WS_ENDPOINT) async def websocket_endpoint(websocket: WebSocket): - # 加载模型(首次连接时自动加载,线程安全) load_model() await websocket.accept() client_ip = websocket.client.host if websocket.client else "unknown_ip" @@ -257,7 +236,7 @@ async def websocket_endpoint(websocket: WebSocket): is_online_updated = False try: - # 处理重复连接(关闭同一IP的旧连接) + # 处理重复连接 if client_ip in connected_clients: old_conn = connected_clients[client_ip] if old_conn.consumer_task and not old_conn.consumer_task.done(): @@ -270,10 +249,9 @@ async def websocket_endpoint(websocket: WebSocket): new_conn = ClientConnection(websocket, client_ip) connected_clients[client_ip] = new_conn new_conn.start_consumer() - # 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧 await new_conn.send_frame_permit() - # 标记上线并记录 + # 标记上线 try: await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) action_data = DeviceActionCreate(client_ip=client_ip, action=1) @@ -285,7 +263,7 @@ async def websocket_endpoint(websocket: WebSocket): print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: 
{len(connected_clients)}") - # 消息循环(接收客户端文本/二进制消息) + # 消息循环 while True: data = await websocket.receive() if "text" in data: @@ -298,13 +276,12 @@ async def websocket_endpoint(websocket: WebSocket): except Exception as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") finally: - # 清理资源并标记离线 + # 清理资源 if client_ip in connected_clients: conn = connected_clients[client_ip] if conn.consumer_task and not conn.consumer_task.done(): conn.consumer_task.cancel() - # 主动/异常断开时标记离线(仅当上线状态更新成功时) if is_online_updated: try: await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)