diff --git a/core/all.py b/core/all.py index 439b14e..0907e35 100644 --- a/core/all.py +++ b/core/all.py @@ -9,7 +9,7 @@ import numpy as np # -------------------------- 核心配置参数 -------------------------- MAX_WORKERS = 6 # 线程池最大线程数 DETECTION_ORDER = ["yolo", "face", "ocr"] # 检测优先级顺序 -TIMEOUT = 30 # 检测超时时间(秒) +TIMEOUT = 30 # 检测超时时间(秒) 【确保此常量可被外部导入】 # -------------------------- 全局状态管理 -------------------------- _executor = None # 线程池实例 @@ -80,7 +80,7 @@ def shutdown(): # -------------------------- 检测逻辑实现 -------------------------- def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple: - """在子线程中执行检测逻辑""" + """在子线程中执行检测逻辑(返回4元素tuple:是否成功、结果、检测器类型、任务ID)""" thread_name = threading.current_thread().name print(f"任务[{task_id}] 开始执行、线程: {thread_name}") @@ -99,27 +99,27 @@ def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple: print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}") if flag: print(f"任务[{task_id}] 完成检测、使用检测器: {detector}") - return (True, result, detector, task_id) + return (True, result, detector, task_id) # 4元素tuple # 所有检测器均未检测到结果 print(f"任务[{task_id}] 所有检测器均未检测到内容") - return (False, "未检测到任何内容", "none", task_id) + return (False, "未检测到任何内容", "none", task_id) # 4元素tuple except Exception as e: print(f"任务[{task_id}] 检测过程发生错误: {str(e)}") - return (False, f"检测错误: {str(e)}", "error", task_id) + return (False, f"检测错误: {str(e)}", "error", task_id) # 4元素tuple # -------------------------- 外部调用接口 -------------------------- def detect(frame: np.ndarray) -> Future: """ - 提交检测任务到线程池 + 提交检测任务到线程池(返回Future对象,需调用result()获取4元素结果) 参数: frame: 待检测图像(ndarray格式、cv2.imdecode生成) 返回: - Future对象、通过result()方法获取检测结果 + Future对象、result()返回tuple: (has_violation, data, detector_type, task_id) """ # 确保模型已加载 if not _model_loaded: @@ -129,8 +129,7 @@ def detect(frame: np.ndarray) -> Future: # 生成任务ID task_id = _get_next_task_id() - # 提交任务到线程池 + # 提交任务到线程池(返回Future) future = _executor.submit(_detect_in_thread, frame, task_id) print(f"任务[{task_id}]: 已提交到线程池") - return future - + return future \ No newline at end of file diff --git a/ws/ws.py b/ws/ws.py index d5fa7f4..9e237d2 100644 --- a/ws/ws.py +++ b/ws/ws.py @@ -7,12 +7,12 @@ from typing import Dict, Optional, AsyncGenerator from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip from service.device_action_service import add_device_action from schema.device_action_schema import DeviceActionCreate -from core.all import detect +# 【修改1:导入detect和TIMEOUT(用于检测超时控制)】 +from core.all import detect, load_model, TIMEOUT import cv2 import numpy as np from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI -from core.all import load_model # 配置常量 HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) @@ -93,7 +93,7 @@ class ClientConnection: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") async def process_frame(self, frame_data: bytes) -> None: - """处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)""" + """处理单帧图像数据(【核心修改:等待检测结果+修正解包】)""" # 二进制数据转OpenCV图像 nparr = np.frombuffer(frame_data, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) @@ -109,13 +109,36 @@ class ClientConnection: try: cv2.imwrite(filename, img) print(f"[{get_current_time_str()}] 图像已保存至: {filename}") - has_violation, data, type = detect(img) - print(has_violation) - print(type) - print(data) + + # -------------------------- 【核心修改1:提交检测任务并等待结果】 -------------------------- + # 1. 提交检测任务获取Future对象(非阻塞) + detection_future = detect(img) + # 2. 用asyncio.to_thread等待Future结果(避免阻塞asyncio事件循环),设置超时 + try: + # 解包4元素结果:(是否违规, 结果数据, 检测器类型, 任务ID) + has_violation, data, detector_type, task_id = await asyncio.to_thread( + detection_future.result, # 调用Future的result()获取实际结果 + timeout=TIMEOUT # 超时控制(与all.py配置一致) + ) + except TimeoutError: + # 处理检测超时场景 + print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测任务超时(超过{TIMEOUT}秒)") + has_violation = False + data = f"检测超时(超过{TIMEOUT}秒)" + detector_type = "timeout" + task_id = -1 # 超时任务ID标记为-1 + # ----------------------------------------------------------------------------------------- + + # -------------------------- 【核心修改2:修正日志打印变量名】 -------------------------- + # 打印检测结果(避免使用Python关键字"type") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - " + f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}, 任务ID: {task_id}") + # ----------------------------------------------------------------------------------------- + + # 处理违规逻辑(变量名从type改为detector_type) if has_violation: - print( - f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - 类型: {type}, 详情: {data}") + print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - " + f"类型: {detector_type}, 详情: {data}") # 调用违规次数加一方法 try: @@ -128,8 +151,12 @@ class ClientConnection: danger_msg = { "type": "danger", "timestamp": get_current_time_str(), - "client_ip": self.client_ip + "client_ip": self.client_ip, + "detector_type": detector_type, + "detail": str(data) } + + # TODO 数据存储到数据库 await self.websocket.send_json(danger_msg) else: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规") @@ -232,7 +259,7 @@ ws_router = APIRouter() @ws_router.websocket(WS_ENDPOINT) async def websocket_endpoint(websocket: WebSocket): - # 加载模型 + # 加载模型(首次连接时自动加载,线程安全) load_model() await websocket.accept() client_ip = websocket.client.host if websocket.client else "unknown_ip" @@ -242,7 +269,7 @@ async def websocket_endpoint(websocket: WebSocket): is_online_updated = False try: - # 处理重复连接 + # 处理重复连接(关闭同一IP的旧连接) if client_ip in connected_clients: old_conn = connected_clients[client_ip] if old_conn.consumer_task and not old_conn.consumer_task.done(): @@ -255,7 +282,7 @@ async def websocket_endpoint(websocket: WebSocket): new_conn = ClientConnection(websocket, client_ip) connected_clients[client_ip] = new_conn new_conn.start_consumer() - # 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧(后续靠取帧后自动发) + # 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧 await new_conn.send_frame_permit() # 标记上线并记录 @@ -270,7 +297,7 @@ async def websocket_endpoint(websocket: WebSocket): print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}") - # 消息循环 + # 消息循环(接收客户端文本/二进制消息) while True: data = await websocket.receive() if "text" in data: @@ -289,7 +316,7 @@ async def websocket_endpoint(websocket: WebSocket): if conn.consumer_task and not conn.consumer_task.done(): conn.consumer_task.cancel() - # 主动/异常断开时标记离线 + # 主动/异常断开时标记离线(仅当上线状态更新成功时) if is_online_updated: try: await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) @@ -300,4 +327,4 @@ async def websocket_endpoint(websocket: WebSocket): print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}") connected_clients.pop(client_ip, None) - print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}") + print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}") \ No newline at end of file