import asyncio import datetime import json import os import base64 from contextlib import asynccontextmanager from typing import Dict, Optional import cv2 import numpy as np from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip from service.device_action_service import add_device_action from schema.device_action_schema import DeviceActionCreate from core.all import detect, load_model # -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息)-------------------------- AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥(32字节) AES_BLOCK_SIZE = 16 # AES固定块大小 def aes_encrypt(plaintext: str) -> dict: """AES-CBC加密:返回{密文, IV, 算法标识}(均Base64编码)- 仅用于服务器发消息""" try: iv = os.urandom(AES_BLOCK_SIZE) # 随机IV(16字节) cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv) padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE) ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8") iv_base64 = base64.b64encode(iv).decode("utf-8") return { "ciphertext": ciphertext, "iv": iv_base64, "algorithm": "AES-CBC" } except Exception as e: raise Exception(f"AES加密失败: {str(e)}") from e # -------------------------- 2. 配置常量(保持原有)-------------------------- HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) WS_ENDPOINT = "/ws" # WebSocket端点路径 FRAME_QUEUE_SIZE = 1 # 帧队列大小限制 # -------------------------- 3. 工具函数(保持原有)-------------------------- def get_current_time_str() -> str: """获取格式化时间字符串""" return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") def get_current_time_file_str() -> str: """获取文件命名用时间字符串""" return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") # -------------------------- 4. 客户端连接封装(服务器发消息仍加密,接收消息改明文)-------------------------- class ClientConnection: def __init__(self, websocket: WebSocket, client_ip: str): self.websocket = websocket self.client_ip = client_ip self.last_heartbeat = datetime.datetime.now() self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) self.consumer_task: Optional[asyncio.Task] = None def update_heartbeat(self): """更新心跳时间""" self.last_heartbeat = datetime.datetime.now() def is_alive(self) -> bool: """判断客户端是否存活""" timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds() return timeout_seconds < HEARTBEAT_TIMEOUT def start_consumer(self): """启动帧消费任务""" self.consumer_task = asyncio.create_task(self.consume_frames()) return self.consumer_task async def send_frame_permit(self): """发送加密的帧许可信号(服务器→客户端:加密)""" try: frame_permit_msg = { "type": "frame", "timestamp": get_current_time_str(), "client_ip": self.client_ip } encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) # 保持加密 await self.websocket.send_json(encrypted_msg) print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可") except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}") async def consume_frames(self) -> None: """消费队列中的明文图像帧并处理""" try: while True: frame_data = await self.frame_queue.get() await self.send_frame_permit() # 回复仍加密 try: await self.process_frame(frame_data) finally: self.frame_queue.task_done() except asyncio.CancelledError: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消") except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}") async def process_frame(self, frame_data: bytes) -> None: """处理明文图像帧(危险通知仍加密发送)""" # 二进制转OpenCV图像(客户端发的是明文二进制,直接解析) nparr = np.frombuffer(frame_data, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if img is None: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像") return try: has_violation, data, detector_type = await asyncio.to_thread( detect, self.client_ip, img ) print( f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}") # 违规通知:服务器→客户端,仍加密 if has_violation: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}") # 违规次数+1 try: await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip) print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数+1") except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}") # 构建危险通知并加密发送 danger_msg = { "type": "danger", "timestamp": get_current_time_str(), "client_ip": self.client_ip, "detail": data } encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) # 保持加密 await self.websocket.send_json(encrypted_danger_msg) except Exception as e: print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}") # -------------------------- 5. 全局状态与心跳管理(保持原有)-------------------------- connected_clients: Dict[str, ClientConnection] = {} heartbeat_task: Optional[asyncio.Task] = None async def heartbeat_checker(): """全局心跳检查任务""" while True: current_time = get_current_time_str() timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] if timeout_ips: print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时(IP: {timeout_ips})") for ip in timeout_ips: try: conn = connected_clients[ip] if conn.consumer_task and not conn.consumer_task.done(): conn.consumer_task.cancel() await conn.websocket.close(code=1008, reason="心跳超时") await asyncio.to_thread(update_online_status_by_ip, ip, 0) action_data = DeviceActionCreate(client_ip=ip, action=0) await asyncio.to_thread(add_device_action, action_data) print(f"[{current_time}] 客户端{ip}: 已标记为离线") except Exception as e: print(f"[{current_time}] 客户端{ip}: 离线处理失败 - {str(e)}") finally: connected_clients.pop(ip, None) else: print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线") await asyncio.sleep(HEARTBEAT_INTERVAL) # -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑)-------------------------- async def send_heartbeat_ack(conn: ClientConnection): """发送加密的心跳确认(服务器→客户端:加密)""" try: heartbeat_ack_msg = { "type": "heart", "timestamp": get_current_time_str(), "client_ip": conn.client_ip } encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) # 保持加密 await conn.websocket.send_json(encrypted_msg) print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认") return True except Exception as e: connected_clients.pop(conn.client_ip, None) print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认失败 - {str(e)}") return False async def handle_text_msg(conn: ClientConnection, text: str): """处理客户端明文文本消息(如心跳)- 关键修改:无需解密,直接解析JSON""" try: # 客户端发的是明文JSON,直接解析(删除原解密步骤) msg = json.loads(text) if msg.get("type") == "heart": conn.update_heartbeat() await send_heartbeat_ack(conn) # 服务器回复仍加密 else: print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')})") except json.JSONDecodeError: print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式(明文文本)") except Exception as e: print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}") # -------------------------- 7. WebSocket路由与生命周期(关键修改:处理明文二进制图像)-------------------------- ws_router = APIRouter() @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期:启动/停止心跳任务""" global heartbeat_task heartbeat_task = asyncio.create_task(heartbeat_checker()) print(f"[{get_current_time_str()}] 心跳检查任务启动(ID: {id(heartbeat_task)})") yield if heartbeat_task and not heartbeat_task.done(): heartbeat_task.cancel() await heartbeat_task print(f"[{get_current_time_str()}] 心跳检查任务已取消") @ws_router.websocket(WS_ENDPOINT) async def websocket_endpoint(websocket: WebSocket): """WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像""" load_model() # 加载检测模型(建议移到全局,避免重复加载) await websocket.accept() client_ip = websocket.client.host if websocket.client else "unknown_ip" current_time = get_current_time_str() print(f"[{current_time}] 客户端{client_ip}: 连接已建立") is_online_updated = False try: # 处理重复连接(关闭旧连接) if client_ip in connected_clients: old_conn = connected_clients[client_ip] if old_conn.consumer_task and not old_conn.consumer_task.done(): old_conn.consumer_task.cancel() await old_conn.websocket.close(code=1008, reason="新连接建立") connected_clients.pop(client_ip) print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接") # 注册新连接 new_conn = ClientConnection(websocket, client_ip) connected_clients[client_ip] = new_conn new_conn.start_consumer() await new_conn.send_frame_permit() # 首次许可仍加密 # 标记客户端上线 try: await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) action_data = DeviceActionCreate(client_ip=client_ip, action=1) await asyncio.to_thread(add_device_action, action_data) print(f"[{current_time}] 客户端{client_ip}: 已标记为在线") is_online_updated = True except Exception as e: print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}") print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}") # 消息循环:接收客户端明文消息(关键修改) while True: data = await websocket.receive() if "text" in data: # 处理客户端明文文本(如心跳:{"type":"heart",...}) await handle_text_msg(new_conn, data["text"]) elif "bytes" in data: # 处理客户端明文二进制图像(直接入队,无需解密) frame_data = data["bytes"] try: new_conn.frame_queue.put_nowait(frame_data) print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队") except asyncio.QueueFull: print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据") except Exception as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}") except WebSocketDisconnect as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})") except Exception as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}") finally: # 清理资源 if client_ip in connected_clients: conn = connected_clients[client_ip] if conn.consumer_task and not conn.consumer_task.done(): conn.consumer_task.cancel() if is_online_updated: try: await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) action_data = DeviceActionCreate(client_ip=client_ip, action=0) await asyncio.to_thread(add_device_action, action_data) print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后标记为离线") except Exception as e: print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}") connected_clients.pop(client_ip, None) print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")