识别结果保存到对应目录下
This commit is contained in:
		
							
								
								
									
										91
									
								
								ws/ws.py
									
									
									
									
									
								
							
							
						
						
									
										91
									
								
								ws/ws.py
									
									
									
									
									
								
							| @ -3,12 +3,11 @@ import datetime | ||||
| import json | ||||
| import os | ||||
| from contextlib import asynccontextmanager | ||||
| from typing import Dict, Optional, AsyncGenerator | ||||
| from typing import Dict, Optional | ||||
| from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip | ||||
| from service.device_action_service import add_device_action | ||||
| from schema.device_action_schema import DeviceActionCreate | ||||
| # 【修改1:导入detect和TIMEOUT(用于检测超时控制)】 | ||||
| from core.all import detect, load_model, TIMEOUT | ||||
| from core.all import detect, load_model | ||||
|  | ||||
| import cv2 | ||||
| import numpy as np | ||||
| @ -21,7 +20,7 @@ WS_ENDPOINT = "/ws"  # WebSocket端点路径 | ||||
| FRAME_QUEUE_SIZE = 1  # 帧队列大小限制 | ||||
|  | ||||
|  | ||||
| # 工具函数: 获取格式化时间字符串(统一时间戳格式) | ||||
| # 工具函数: 获取格式化时间字符串 | ||||
| def get_current_time_str() -> str: | ||||
|     return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") | ||||
|  | ||||
| @ -40,13 +39,13 @@ class ClientConnection: | ||||
|         self.consumer_task: Optional[asyncio.Task] = None | ||||
|  | ||||
|     def update_heartbeat(self): | ||||
|         """更新心跳时间(客户端发送心跳时调用)""" | ||||
|         """更新心跳时间""" | ||||
|         self.last_heartbeat = datetime.datetime.now() | ||||
|  | ||||
|     def is_alive(self) -> bool: | ||||
|         """判断客户端是否存活(心跳超时检查)""" | ||||
|         timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() | ||||
|         return timeout < HEARTBEAT_TIMEOUT | ||||
|         """判断客户端是否存活""" | ||||
|         timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds() | ||||
|         return timeout_seconds < HEARTBEAT_TIMEOUT | ||||
|  | ||||
|     def start_consumer(self): | ||||
|         """启动帧消费任务""" | ||||
| @ -54,10 +53,7 @@ class ClientConnection: | ||||
|         return self.consumer_task | ||||
|  | ||||
|     async def send_frame_permit(self): | ||||
|         """ | ||||
|         发送「帧发送许可信号」 | ||||
|         通知客户端可发送下一帧图像 | ||||
|         """ | ||||
|         """发送帧发送许可信号""" | ||||
|         try: | ||||
|             frame_permit_msg = { | ||||
|                 "type": "frame", | ||||
| @ -65,26 +61,21 @@ class ClientConnection: | ||||
|                 "client_ip": self.client_ip | ||||
|             } | ||||
|             await self.websocket.send_json(frame_permit_msg) | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号(取帧后立即通知)") | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号") | ||||
|         except Exception as e: | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}") | ||||
|  | ||||
|     async def consume_frames(self) -> None: | ||||
|         """消费队列中的帧并处理(核心调整: 取帧后立即发许可、再处理帧)""" | ||||
|         """消费队列中的帧并处理""" | ||||
|         try: | ||||
|             while True: | ||||
|                 # 1. 从队列取出帧(阻塞直到有帧可用) | ||||
|                 # 取出帧并立即发送下一帧许可 | ||||
|                 frame_data = await self.frame_queue.get() | ||||
|  | ||||
|                 # -------------------------- 核心修改: 取出帧后立即发送下一帧许可 -------------------------- | ||||
|                 await self.send_frame_permit()  # 取帧即通知客户端发下一帧、无需等处理完成 | ||||
|                 # ----------------------------------------------------------------------------------------- | ||||
|                 await self.send_frame_permit() | ||||
|  | ||||
|                 try: | ||||
|                     # 2. 处理取出的帧(即使处理慢、客户端也已收到许可、可提前准备下一帧) | ||||
|                     await self.process_frame(frame_data) | ||||
|                 finally: | ||||
|                     # 3. 标记帧任务完成(无论处理成功/失败、都需清理队列) | ||||
|                     self.frame_queue.task_done() | ||||
|  | ||||
|         except asyncio.CancelledError: | ||||
| @ -93,8 +84,8 @@ class ClientConnection: | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") | ||||
|  | ||||
|     async def process_frame(self, frame_data: bytes) -> None: | ||||
|         """处理单帧图像数据""" | ||||
|         # 二进制数据转OpenCV图像 | ||||
|         """处理单帧图像数据(核心修复:按3个返回值解包)""" | ||||
|         # 二进制转OpenCV图像 | ||||
|         nparr = np.frombuffer(frame_data, np.uint8) | ||||
|         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | ||||
|         if img is None: | ||||
| @ -102,52 +93,41 @@ class ClientConnection: | ||||
|             return | ||||
|  | ||||
|         try: | ||||
|             # -------------------------- 提交检测任务并等待结果 -------------------------- | ||||
|             # 1. 提交检测任务获取Future对象(非阻塞) | ||||
|             detection_future = detect(img) | ||||
|             # 2. 用asyncio.to_thread等待Future结果(避免阻塞asyncio事件循环),设置超时 | ||||
|             try: | ||||
|                 # 解包4元素结果:(是否违规, 结果数据, 检测器类型, 任务ID) | ||||
|                 has_violation, data, detector_type, task_id = await asyncio.to_thread( | ||||
|                     detection_future.result,  # 调用Future的result()获取实际结果 | ||||
|                     timeout=TIMEOUT  # 超时控制(与all.py配置一致) | ||||
|                 ) | ||||
|             except TimeoutError: | ||||
|                 # 处理检测超时场景 | ||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测任务超时(超过{TIMEOUT}秒)") | ||||
|                 has_violation = False | ||||
|                 data = f"检测超时(超过{TIMEOUT}秒)" | ||||
|                 detector_type = "timeout" | ||||
|                 task_id = -1  # 超时任务ID标记为-1 | ||||
|             # ----------------------------------------------------------------------------------------- | ||||
|             # -------------------------- 修复核心:匹配detect返回的3个值 -------------------------- | ||||
|             # 假设detect返回 (是否违规, 结果数据, 检测器类型) | ||||
|             has_violation, data, detector_type = await asyncio.to_thread( | ||||
|                 detect,  # 调用检测函数 | ||||
|                 img      # 传入图像参数 | ||||
|             ) | ||||
|             # ------------------------------------------------------------------------------------- | ||||
|  | ||||
|             # 打印检测结果 | ||||
|             # 打印检测结果(移除task_id相关内容) | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - " | ||||
|                   f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}, 任务ID: {task_id}") | ||||
|                   f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}") | ||||
|  | ||||
|             # 处理违规逻辑 | ||||
|             if has_violation: | ||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - " | ||||
|                       f"类型: {detector_type}, 详情: {data}") | ||||
|  | ||||
|                 # 调用违规次数加一方法 | ||||
|                 # 违规次数+1 | ||||
|                 try: | ||||
|                     await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip) | ||||
|                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1") | ||||
|                 except Exception as e: | ||||
|                     print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") | ||||
|  | ||||
|                 # 发送「危险通知」 | ||||
|                 # 发送危险通知 | ||||
|                 danger_msg = { | ||||
|                     "type": "danger", | ||||
|                     "timestamp": get_current_time_str(), | ||||
|                     "client_ip": self.client_ip | ||||
|                     "client_ip": self.client_ip, | ||||
|                     "detail": data | ||||
|                 } | ||||
|  | ||||
|                 # TODO 数据存储到数据库 | ||||
|                 await self.websocket.send_json(danger_msg) | ||||
|             else: | ||||
|                 print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规") | ||||
|  | ||||
|         except Exception as e: | ||||
|             print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}") | ||||
|  | ||||
| @ -157,7 +137,7 @@ connected_clients: Dict[str, ClientConnection] = {} | ||||
| heartbeat_task: Optional[asyncio.Task] = None | ||||
|  | ||||
|  | ||||
| # 心跳检查(定时清理超时客户端 + 调用离线状态更新方法) | ||||
| # 心跳检查任务 | ||||
| async def heartbeat_checker(): | ||||
|     while True: | ||||
|         current_time = get_current_time_str() | ||||
| @ -172,7 +152,7 @@ async def heartbeat_checker(): | ||||
|                         conn.consumer_task.cancel() | ||||
|                     await conn.websocket.close(code=1008, reason="心跳超时") | ||||
|  | ||||
|                     # 超时设为离线并记录 | ||||
|                     # 标记离线 | ||||
|                     try: | ||||
|                         await asyncio.to_thread(update_online_status_by_ip, ip, 0) | ||||
|                         action_data = DeviceActionCreate(client_ip=ip, action=0) | ||||
| @ -247,7 +227,6 @@ ws_router = APIRouter() | ||||
|  | ||||
| @ws_router.websocket(WS_ENDPOINT) | ||||
| async def websocket_endpoint(websocket: WebSocket): | ||||
|     # 加载模型(首次连接时自动加载,线程安全) | ||||
|     load_model() | ||||
|     await websocket.accept() | ||||
|     client_ip = websocket.client.host if websocket.client else "unknown_ip" | ||||
| @ -257,7 +236,7 @@ async def websocket_endpoint(websocket: WebSocket): | ||||
|     is_online_updated = False | ||||
|  | ||||
|     try: | ||||
|         # 处理重复连接(关闭同一IP的旧连接) | ||||
|         # 处理重复连接 | ||||
|         if client_ip in connected_clients: | ||||
|             old_conn = connected_clients[client_ip] | ||||
|             if old_conn.consumer_task and not old_conn.consumer_task.done(): | ||||
| @ -270,10 +249,9 @@ async def websocket_endpoint(websocket: WebSocket): | ||||
|         new_conn = ClientConnection(websocket, client_ip) | ||||
|         connected_clients[client_ip] = new_conn | ||||
|         new_conn.start_consumer() | ||||
|         # 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧 | ||||
|         await new_conn.send_frame_permit() | ||||
|  | ||||
|         # 标记上线并记录 | ||||
|         # 标记上线 | ||||
|         try: | ||||
|             await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) | ||||
|             action_data = DeviceActionCreate(client_ip=client_ip, action=1) | ||||
| @ -285,7 +263,7 @@ async def websocket_endpoint(websocket: WebSocket): | ||||
|  | ||||
|         print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}") | ||||
|  | ||||
|         # 消息循环(接收客户端文本/二进制消息) | ||||
|         # 消息循环 | ||||
|         while True: | ||||
|             data = await websocket.receive() | ||||
|             if "text" in data: | ||||
| @ -298,13 +276,12 @@ async def websocket_endpoint(websocket: WebSocket): | ||||
|     except Exception as e: | ||||
|         print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") | ||||
|     finally: | ||||
|         # 清理资源并标记离线 | ||||
|         # 清理资源 | ||||
|         if client_ip in connected_clients: | ||||
|             conn = connected_clients[client_ip] | ||||
|             if conn.consumer_task and not conn.consumer_task.done(): | ||||
|                 conn.consumer_task.cancel() | ||||
|  | ||||
|             # 主动/异常断开时标记离线(仅当上线状态更新成功时) | ||||
|             if is_online_updated: | ||||
|                 try: | ||||
|                     await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user