This commit is contained in:
ZZX9599
2025-09-04 17:33:20 +08:00
parent 3ed73bd9eb
commit ec6dbfde90

403
ws/ws.py
View File

@ -4,6 +4,8 @@ import json
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Optional, AsyncGenerator from typing import Dict, Optional, AsyncGenerator
from concurrent.futures import ThreadPoolExecutor # 新增:显式线程池
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate from schema.device_action_schema import DeviceActionCreate
@ -11,372 +13,305 @@ from schema.device_action_schema import DeviceActionCreate
import cv2 import cv2
import numpy as np import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from queue import Queue from queue import Queue # 线程安全队列无需额外Lock
from threading import Lock
from ocr.model_violation_detector import MultiModelViolationDetector from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件路径(建议实际部署时改为相对路径或环境变量) # -------------------------- 配置调整 --------------------------
# 模型路径(建议改为环境变量)
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
# 模型池配置根据GPU显存调整每个模型约占1G显存) # 核心优化:模型池大小(决定最大并发任务数,显存占用=大小×单模型显存)
MODEL_POOL_SIZE = 3 # 最大并发客户端数 MODEL_POOL_SIZE = 5 # 示例设为5支持5个任务并行显存会明显上升
THREAD_POOL_SIZE = MODEL_POOL_SIZE * 2 # 线程池大小≥模型池,避免线程瓶颈
# 配置常量 # 其他配置
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_INTERVAL = 30 # 心跳间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径 WS_ENDPOINT = "/ws" # WebSocket端点
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制 FRAME_QUEUE_SIZE = 5 # 增大帧队列,允许缓存更多帧(避免丢帧)
# -------------------------- 工具函数 --------------------------
# 工具函数:获取格式化时间字符串
def get_current_time_str() -> str: def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str: def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# -------------------------- 模型池重构核心修改1 --------------------------
# 模型池实现 - 提前初始化固定数量的模型实例
class ModelPool: class ModelPool:
def __init__(self, pool_size: int = MODEL_POOL_SIZE): def __init__(self, pool_size: int = MODEL_POOL_SIZE):
self.pool = Queue(maxsize=pool_size) self.pool = Queue(maxsize=pool_size)
self.lock = Lock() # 移除冗余LockQueue.get()/put()本身线程安全
# 提前初始化模型实例(显存会在此阶段预分配) self._init_models(pool_size)
print(f"[{get_current_time_str()}] 模型池初始化完成(共{pool_size}个实例,显存已预分配)")
def _init_models(self, pool_size: int):
"""预加载所有模型实例(初始化时显存会一次性上升)"""
for i in range(pool_size): for i in range(pool_size):
detector = MultiModelViolationDetector( try:
ocr_config_path=OCR_CONFIG_PATH, detector = MultiModelViolationDetector(
yolo_model_path=YOLO_MODEL_PATH, ocr_config_path=OCR_CONFIG_PATH,
ocr_confidence_threshold=0.5 yolo_model_path=YOLO_MODEL_PATH,
) ocr_confidence_threshold=0.5
self.pool.put(detector) )
print(f"[{get_current_time_str()}] 模型池初始化:第{i + 1}/{pool_size}个模型加载完成") self.pool.put(detector)
print(f"[{get_current_time_str()}] 模型实例{i+1}/{pool_size}加载完成")
except Exception as e:
raise RuntimeError(f"模型实例{i+1}加载失败:{str(e)}")
def get_model(self) -> MultiModelViolationDetector: def get_model(self) -> MultiModelViolationDetector:
"""从池子里获取模型(阻塞直到有可用实例""" """获取模型(阻塞直到有空闲实例,确保并发安全"""
with self.lock: return self.pool.get()
return self.pool.get()
def return_model(self, detector: MultiModelViolationDetector): def return_model(self, detector: MultiModelViolationDetector):
"""将模型归还给池子""" """归还模型(立即释放资源供其他任务使用)"""
with self.lock: self.pool.put(detector)
self.pool.put(detector)
# -------------------------- 全局资源初始化 --------------------------
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE) # 初始化模型池(预占显存)
thread_pool = ThreadPoolExecutor( # 显式创建线程池核心修改2
max_workers=THREAD_POOL_SIZE,
thread_name_prefix="ModelWorker-" # 线程命名,便于调试
)
# 初始化模型池(程序启动时加载所有模型,显存会一次性占用 MODEL_POOL_SIZE * 单模型显存) # -------------------------- 客户端连接封装核心修改3 --------------------------
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE)
# 客户端连接封装
class ClientConnection: class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str): def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket self.websocket = websocket
self.client_ip = client_ip self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 增大队列
self.consumer_task: Optional[asyncio.Task] = None self.consumer_task: Optional[asyncio.Task] = None
# 移除“客户端独占模型”不再持有detector属性
# 从模型池获取专属模型(每个客户端独立占用一个模型实例)
self.detector = model_pool.get_model()
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已获取模型池中的模型实例(显存独立)")
def update_heartbeat(self): def update_heartbeat(self):
"""更新心跳时间"""
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
def is_alive(self) -> bool: def is_alive(self) -> bool:
"""判断客户端是否存活"""
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT return timeout < HEARTBEAT_TIMEOUT
def start_consumer(self): def start_consumer(self):
"""启动帧消费任务""" """启动帧消费任务(每个客户端一个独立任务)"""
self.consumer_task = asyncio.create_task(self.consume_frames()) self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task return self.consumer_task
def release_model(self):
"""客户端断开时归还模型到池"""
model_pool.return_model(self.detector)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:模型已归还至模型池(显存可复用)")
async def send_frame_permit(self): async def send_frame_permit(self):
"""发送帧发送许可信号""" """发送帧许可信号(允许客户端继续发帧)"""
try: try:
frame_permit_msg = { await self.websocket.send_json({
"type": "frame", "type": "frame",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
"client_ip": self.client_ip "client_ip": self.client_ip
} })
await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可发送失败 - {str(e)}")
async def consume_frames(self) -> None: async def consume_frames(self) -> None:
"""消费队列中的帧并处理(并行执行核心""" """消费队列(并发核心:每帧临时借模型处理"""
try: try:
while True: while True:
# 1. 从队列取 # 1. 从队列取帧(无帧时阻塞)
frame_data = await self.frame_queue.get() frame_data = await self.frame_queue.get()
# 2. 立即发送下一帧许可(让客户端持续发帧,积累并发任务)
# 2. 立即发送下一帧许可
await self.send_frame_permit() await self.send_frame_permit()
try: try:
# 3. 并行处理帧用线程池执行AI检测真正并发 # 3. 并行处理帧(核心:任务级借模型
await self.process_frame(frame_data) await self.process_frame(frame_data)
finally: finally:
self.frame_queue.task_done() self.frame_queue.task_done() # 标记帧处理完成
except asyncio.CancelledError: except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}消费逻辑错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None: async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(使用客户端专属模型)""" """处理单帧核心修改4任务级借还模型)"""
# 二进制数据转OpenCV图像 # 1. 临时借用模型(阻塞直到有空闲实例,显存随借用数上升)
nparr = np.frombuffer(frame_data, np.uint8) detector = model_pool.get_model()
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据")
return
# 确保图像保存目录存在
os.makedirs('images', exist_ok=True)
# 保存图像
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
try: try:
cv2.imwrite(filename, img) # 2. 二进制转OpenCV图像
print(f"[{get_current_time_str()}] 图像已保存至:{filename}") nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像解析失败")
return
# 关键修改:使用客户端专属模型 + 线程池并行执行AI检测 # 3. 保存图像(可选)
has_violation, violation_type, details = await asyncio.to_thread( os.makedirs('images', exist_ok=True)
self.detector.detect_violations, # 客户端独立模型 filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
cv2.imwrite(filename, img)
# 4. 显式线程池执行AI检测真正并发无线程瓶颈
loop = asyncio.get_running_loop()
has_violation, violation_type, details = await loop.run_in_executor(
thread_pool, # 用自定义线程池,避免默认线程不足
detector.detect_violations, # 临时借用的模型
img # 输入图像 img # 输入图像
) )
# 5. 违规处理(与原逻辑一致)
if has_violation: if has_violation:
print( print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规 - {violation_type}")
f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {violation_type}, 详情: {details}") # 违规次数更新(用线程池避免阻塞事件循环)
await loop.run_in_executor(thread_pool, increment_alarm_count_by_ip, self.client_ip)
# 调用违规次数加一方法
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
# 发送危险通知 # 发送危险通知
danger_msg = { await self.websocket.send_json({
"type": "danger", "type": "danger",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
"client_ip": self.client_ip "client_ip": self.client_ip,
} "violation_type": violation_type,
await self.websocket.send_json(danger_msg) "details": details
})
else: else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}未检测到违规") print(f"[{get_current_time_str()}] 客户端{self.client_ip}违规")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}图像处理错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}处理错误 - {str(e)}")
finally:
# 6. 无论成功/失败,强制归还模型(核心:释放资源供其他任务使用)
model_pool.return_model(detector)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:模型已归还(可复用)")
# -------------------------- 全局状态与心跳 --------------------------
# 全局状态管理
connected_clients: Dict[str, ClientConnection] = {} connected_clients: Dict[str, ClientConnection] = {}
client_lock = asyncio.Lock() # 保护connected_clients的 client_lock = asyncio.Lock() # 保护客户端字典的异步
heartbeat_task: Optional[asyncio.Task] = None heartbeat_task: Optional[asyncio.Task] = None
# 心跳检查
async def heartbeat_checker(): async def heartbeat_checker():
"""心跳检查(移除模型归还逻辑,因模型已任务级归还)"""
while True: while True:
current_time = get_current_time_str() current_time = get_current_time_str()
# 加锁保护字典遍历
async with client_lock: async with client_lock:
# 筛选超时客户端
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips: for ip in timeout_ips:
print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时IP{timeout_ips}")
for ip in timeout_ips:
try:
async with client_lock:
conn = connected_clients.get(ip)
if not conn:
continue
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 归还模型
conn.release_model()
# 超时设为离线并记录
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
finally:
async with client_lock:
connected_clients.pop(ip, None)
else:
async with client_lock: async with client_lock:
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线") conn = connected_clients.get(ip)
if not conn:
continue
# 取消消费任务+关闭连接
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线(用线程池)
loop = asyncio.get_running_loop()
await loop.run_in_executor(thread_pool, update_online_status_by_ip, ip, 0)
await loop.run_in_executor(
thread_pool, add_device_action, DeviceActionCreate(client_ip=ip, action=0)
)
connected_clients.pop(ip)
print(f"[{current_time}] 客户端{ip}:超时离线(资源已清理)")
# 打印在线状态
async with client_lock:
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 应用生命周期核心修改5管理线程池 --------------------------
# 应用生命周期管理
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global heartbeat_task global heartbeat_task
# 启动心跳任务
heartbeat_task = asyncio.create_task(heartbeat_checker()) heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID{id(heartbeat_task)}") print(f"[{get_current_time_str()}] 心跳任务启动ID{id(heartbeat_task)}")
yield print(f"[{get_current_time_str()}] 线程池启动(最大线程数:{THREAD_POOL_SIZE}")
yield # 应用运行期间
# 清理资源
if heartbeat_task and not heartbeat_task.done(): if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel() heartbeat_task.cancel()
try: await heartbeat_task
await heartbeat_task print(f"[{get_current_time_str()}] 心跳任务已关闭")
print(f"[{get_current_time_str()}] 全局心跳检查任务已取消") # 关闭线程池(等待所有任务完成)
except asyncio.CancelledError: thread_pool.shutdown(wait=True)
pass print(f"[{get_current_time_str()}] 线程池已关闭")
# -------------------------- WebSocket路由 --------------------------
# 消息处理工具函数
async def send_heartbeat_ack(conn: ClientConnection):
try:
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
return True
except Exception as e:
async with client_lock:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
return False
async def handle_text_msg(conn: ClientConnection, text: str):
try:
msg = json.loads(text)
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(conn)
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}无效JSON文本消息")
async def handle_binary_msg(conn: ClientConnection, data: bytes):
try:
conn.frame_queue.put_nowait(data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:帧队列已满,丢弃当前图像数据")
# WebSocket路由配置
ws_router = APIRouter() ws_router = APIRouter()
@ws_router.websocket(WS_ENDPOINT) @ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip" client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str() current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}WebSocket连接建立") print(f"[{current_time}] 客户端{client_ip}:连接建立")
is_online_updated = False
new_conn = None new_conn = None
is_online_updated = False
try: try:
# 处理重复连接 # 处理重复连接(关闭旧连接)
async with client_lock: async with client_lock:
if client_ip in connected_clients: if client_ip in connected_clients:
old_conn = connected_clients[client_ip] old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done(): if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel() old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立") await old_conn.websocket.close(code=1008, reason="新连接抢占")
old_conn.release_model() # 归还旧连接的模型
connected_clients.pop(client_ip) connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接并回收模型") print(f"[{current_time}] 客户端{client_ip}旧连接已关闭")
# 注册新连接 # 创建新连接+启动消费任务
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
async with client_lock:
connected_clients[client_ip] = new_conn
new_conn.start_consumer() new_conn.start_consumer()
# 初始许可:连接建立后立即发一次 # 初始发送帧许可(让客户端立即发帧)
await new_conn.send_frame_permit() await new_conn.send_frame_permit()
# 标记上线并记录 # 标记客户端在线
try: loop = asyncio.get_running_loop()
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1) await loop.run_in_executor(
await asyncio.to_thread(add_device_action, action_data) thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=1)
print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作") )
is_online_updated = True is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
async with client_lock: async with client_lock:
print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}") connected_clients[client_ip] = new_conn
print(f"[{current_time}] 客户端{client_ip}:注册成功(在线数:{len(connected_clients)}")
# 消息循环 # 消息循环(接收文本/二进制帧)
while True: while True:
data = await websocket.receive() data = await websocket.receive()
if "text" in data: if "text" in data:
await handle_text_msg(new_conn, data["text"]) # 处理文本消息(如心跳)
try:
msg = json.loads(data["text"])
if msg.get("type") == "heart":
new_conn.update_heartbeat()
# 回复心跳确认
await websocket.send_json({
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": client_ip
})
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{client_ip}无效JSON")
elif "bytes" in data: elif "bytes" in data:
await handle_binary_msg(new_conn, data["bytes"]) # 处理二进制帧(图像)
try:
await new_conn.frame_queue.put(data["bytes"])
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧已入队(队列大小:{new_conn.frame_queue.qsize()}")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{client_ip}:帧队列满(丢弃当前帧)")
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code}") print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开(代码:{e.code}")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}") print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
finally: finally:
# 清理资源并标记离线 # 清理资源无需归还模型已在process_frame中归还
if new_conn and client_ip in connected_clients: if new_conn and client_ip in connected_clients:
async with client_lock: async with client_lock:
conn = connected_clients.get(client_ip) conn = connected_clients.get(client_ip)
if conn: if conn:
if conn.consumer_task and not conn.consumer_task.done(): if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel() conn.consumer_task.cancel()
# 标记离线(仅当在线状态已更新时)
# 归还模型到模型池
conn.release_model()
# 主动/异常断开时标记离线
if is_online_updated: if is_online_updated:
try: loop = asyncio.get_running_loop()
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) await loop.run_in_executor(thread_pool, update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0) await loop.run_in_executor(
await asyncio.to_thread(add_device_action, action_data) thread_pool, add_device_action, DeviceActionCreate(client_ip=client_ip, action=0)
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线") )
except Exception as e: connected_clients.pop(client_ip)
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
async with client_lock: async with client_lock:
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源清理在线数:{len(connected_clients)}") print(f"[{get_current_time_str()}] 客户端{client_ip}:资源清理完成(在线数:{len(connected_clients)}")
# 创建FastAPI应用
app = FastAPI(lifespan=lifespan)
app.include_router(ws_router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)