import asyncio import datetime import json from contextlib import asynccontextmanager from typing import Dict, Optional import numpy as np from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI # -------------------------- 配置常量(简化硬编码) -------------------------- HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_TIMEOUT = 60 # 客户端超时阈值(秒) WS_ENDPOINT = "/ws" # WebSocket端点路径 # -------------------------- 核心数据结构与全局变量 -------------------------- ws_router = APIRouter() # 客户端连接封装(仅保留核心属性和方法) class ClientConnection: def __init__(self, websocket: WebSocket, client_ip: str): self.websocket = websocket self.client_ip = client_ip self.last_heartbeat = datetime.datetime.now() # 更新心跳时间 def update_heartbeat(self): self.last_heartbeat = datetime.datetime.now() # 检查是否存活(超时返回False) def is_alive(self) -> bool: timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() return timeout < HEARTBEAT_TIMEOUT # 全局连接管理(IP -> 连接实例) connected_clients: Dict[str, ClientConnection] = {} # 心跳任务(全局引用,用于关闭时清理) heartbeat_task: Optional[asyncio.Task] = None # -------------------------- 心跳检查逻辑(精简日志) -------------------------- async def heartbeat_checker(): while True: now = datetime.datetime.now() # 1. 筛选超时客户端(避免遍历中修改字典) timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] # 2. 处理超时连接(关闭+移除) if timeout_ips: print(f"[{now:%H:%M:%S}] 心跳检查:{len(timeout_ips)}个客户端超时({timeout_ips})") for ip in timeout_ips: try: await connected_clients[ip].websocket.close(code=1008, reason="心跳超时") finally: connected_clients.pop(ip, None) else: print(f"[{now:%H:%M:%S}] 心跳检查:{len(connected_clients)}个客户端在线,无超时") # 3. 等待下一轮检查 await asyncio.sleep(HEARTBEAT_INTERVAL) # -------------------------- 应用生命周期(简化异常处理) -------------------------- @asynccontextmanager async def lifespan(app: FastAPI): global heartbeat_task # 启动心跳任务 heartbeat_task = asyncio.create_task(heartbeat_checker()) print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务启动(ID:{id(heartbeat_task)})") yield # 关闭时取消心跳任务 if heartbeat_task and not heartbeat_task.done(): heartbeat_task.cancel() try: await heartbeat_task print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务已取消") except asyncio.CancelledError: pass # -------------------------- 消息处理(合并冗余逻辑) -------------------------- async def send_heartbeat_ack(client_ip: str): """回复心跳确认""" if client_ip not in connected_clients: return False try: ack = { "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "type": "heartbeat" } await connected_clients[client_ip].websocket.send_json(ack) return True except Exception: connected_clients.pop(client_ip, None) return False async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection): """处理文本消息(核心:心跳+JSON解析)""" try: msg = json.loads(text) # 仅处理心跳类型消息 if msg.get("type") == "heartbeat": conn.update_heartbeat() await send_heartbeat_ack(client_ip) else: print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到消息:{msg}") except json.JSONDecodeError: print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:无效JSON消息") async def handle_binary_msg(client_ip: str, data: bytes): """处理二进制消息(保留扩展入口)""" print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到{len(data)}字节二进制数据") # 将二进制转化为 numpy 数组 data_ndarray = np.frombuffer(data, dtype=np.uint8) # 进行检测 # -------------------------- WebSocket核心端点 -------------------------- @ws_router.websocket(WS_ENDPOINT) async def websocket_endpoint(websocket: WebSocket): # 接受连接 + 获取客户端IP await websocket.accept() client_ip = websocket.client.host if websocket.client else "unknown" now = datetime.datetime.now() print(f"[{now:%H:%M:%S}] 客户端{client_ip}:连接成功") try: # 处理重复连接(关闭旧连接) if client_ip in connected_clients: await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接") connected_clients.pop(client_ip) print(f"[{now:%H:%M:%S}] 客户端{client_ip}:关闭旧连接") # 注册新连接 new_conn = ClientConnection(websocket, client_ip) connected_clients[client_ip] = new_conn print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,当前在线{len(connected_clients)}个") # 循环接收消息 while True: data = await websocket.receive() if "text" in data: await handle_text_msg(client_ip, data["text"], new_conn) elif "bytes" in data: await handle_binary_msg(client_ip, data["bytes"]) # 异常处理(断开/错误) except WebSocketDisconnect as e: print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:主动断开(代码:{e.code})") except Exception as e: print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接异常({str(e)[:50]})") finally: # 清理连接 connected_clients.pop(client_ip, None) print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接已清理,当前在线{len(connected_clients)}个")