Files
video/ws/ws.py
2025-09-04 22:59:27 +08:00

304 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import datetime
import json
import os
from contextlib import asynccontextmanager
from typing import Dict, Optional, AsyncGenerator
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect
import cv2
import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from core.all import load_model
# 配置常量
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# 工具函数:获取格式化时间字符串(统一时间戳格式)
def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# 客户端连接封装
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None
def update_heartbeat(self):
"""更新心跳时间(客户端发送心跳时调用)"""
self.last_heartbeat = datetime.datetime.now()
def is_alive(self) -> bool:
"""判断客户端是否存活(心跳超时检查)"""
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT
def start_consumer(self):
"""启动帧消费任务"""
self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task
async def send_frame_permit(self):
"""
发送「帧发送许可信号」
通知客户端可发送下一帧图像
"""
try:
frame_permit_msg = {
"type": "frame",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
}
await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}")
async def consume_frames(self) -> None:
"""消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理帧)"""
try:
while True:
# 1. 从队列取出帧(阻塞直到有帧可用)
frame_data = await self.frame_queue.get()
# -------------------------- 核心修改:取出帧后立即发送下一帧许可 --------------------------
await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成
# -----------------------------------------------------------------------------------------
try:
# 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧)
await self.process_frame(frame_data)
finally:
# 3. 标记帧任务完成(无论处理成功/失败,都需清理队列)
self.frame_queue.task_done()
except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)"""
# 二进制数据转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据")
return
# 确保图像保存目录存在
os.makedirs('images', exist_ok=True)
# 保存图像按IP+时间戳命名,避免冲突)
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
try:
cv2.imwrite(filename, img)
print(f"[{get_current_time_str()}] 图像已保存至:{filename}")
has_violation, data, type = detect(img)
print(has_violation)
print(type)
print(data)
if has_violation:
print(
f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {type}, 详情: {data}")
# 调用违规次数加一方法
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
# 发送「危险通知」
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
}
await self.websocket.send_json(danger_msg)
else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:未检测到违规")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像处理错误 - {str(e)}")
# 全局状态管理
connected_clients: Dict[str, ClientConnection] = {}
heartbeat_task: Optional[asyncio.Task] = None
# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法)
async def heartbeat_checker():
while True:
current_time = get_current_time_str()
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips:
print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时IP{timeout_ips}")
for ip in timeout_ips:
try:
conn = connected_clients[ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 超时设为离线并记录
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
finally:
connected_clients.pop(ip, None)
else:
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL)
# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
global heartbeat_task
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 全局心跳检查任务启动任务ID{id(heartbeat_task)}")
yield
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
try:
await heartbeat_task
print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
except asyncio.CancelledError:
pass
# 消息处理工具函数
async def send_heartbeat_ack(conn: ClientConnection):
try:
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
return True
except Exception as e:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
return False
async def handle_text_msg(conn: ClientConnection, text: str):
try:
msg = json.loads(text)
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(conn)
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}无效JSON文本消息")
async def handle_binary_msg(conn: ClientConnection, data: bytes):
try:
conn.frame_queue.put_nowait(data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:帧队列已满,丢弃当前图像数据")
# WebSocket路由配置
ws_router = APIRouter()
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
# 加载模型
load_model()
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}WebSocket连接已建立")
is_online_updated = False
try:
# 处理重复连接
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接")
# 注册新连接
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer()
# 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发)
await new_conn.send_frame_permit()
# 标记上线并记录
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}")
# 消息循环
while True:
data = await websocket.receive()
if "text" in data:
await handle_text_msg(new_conn, data["text"])
elif "bytes" in data:
await handle_binary_msg(new_conn, data["bytes"])
except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code}")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
finally:
# 清理资源并标记离线
if client_ip in connected_clients:
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 主动/异常断开时标记离线
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}")