Files
video/ws/ws.py
2025-09-16 20:17:48 +08:00

310 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import datetime
import json
import os
import base64
from contextlib import asynccontextmanager
from typing import Dict, Optional
import cv2
import numpy as np
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect, load_model
# -------------------------- 1. AES 加密工具(仅用于服务器向客户端发送消息)--------------------------
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥32字节
AES_BLOCK_SIZE = 16 # AES固定块大小
def aes_encrypt(plaintext: str) -> dict:
"""AES-CBC加密返回{密文, IV, 算法标识}均Base64编码- 仅用于服务器发消息"""
try:
iv = os.urandom(AES_BLOCK_SIZE) # 随机IV16字节
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
iv_base64 = base64.b64encode(iv).decode("utf-8")
return {
"ciphertext": ciphertext,
"iv": iv_base64,
"algorithm": "AES-CBC"
}
except Exception as e:
raise Exception(f"AES加密失败: {str(e)}") from e
# -------------------------- 2. 配置常量(保持原有)--------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# -------------------------- 3. 工具函数(保持原有)--------------------------
def get_current_time_str() -> str:
"""获取格式化时间字符串"""
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str:
"""获取文件命名用时间字符串"""
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# -------------------------- 4. 客户端连接封装(服务器发消息仍加密,接收消息改明文)--------------------------
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None
def update_heartbeat(self):
"""更新心跳时间"""
self.last_heartbeat = datetime.datetime.now()
def is_alive(self) -> bool:
"""判断客户端是否存活"""
timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout_seconds < HEARTBEAT_TIMEOUT
def start_consumer(self):
"""启动帧消费任务"""
self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task
async def send_frame_permit(self):
"""发送加密的帧许可信号(服务器→客户端:加密)"""
try:
frame_permit_msg = {
"type": "frame",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
}
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg)) # 保持加密
await self.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}")
async def consume_frames(self) -> None:
"""消费队列中的明文图像帧并处理"""
try:
while True:
frame_data = await self.frame_queue.get()
await self.send_frame_permit() # 回复仍加密
try:
await self.process_frame(frame_data)
finally:
self.frame_queue.task_done()
except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理明文图像帧(危险通知仍加密发送)"""
# 二进制转OpenCV图像客户端发的是明文二进制直接解析
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析明文图像")
return
try:
has_violation, data, detector_type = await asyncio.to_thread(
detect, self.client_ip, img
)
print(
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}")
# 违规通知:服务器→客户端,仍加密
if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}")
# 违规次数+1
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 构建危险通知并加密发送
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip,
"detail": data
}
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg)) # 保持加密
await self.websocket.send_json(encrypted_danger_msg)
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 明文图像处理错误 - {str(e)}")
# -------------------------- 5. 全局状态与心跳管理(保持原有)--------------------------
connected_clients: Dict[str, ClientConnection] = {}
heartbeat_task: Optional[asyncio.Task] = None
async def heartbeat_checker():
"""全局心跳检查任务"""
while True:
current_time = get_current_time_str()
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips:
print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时IP: {timeout_ips}")
for ip in timeout_ips:
try:
conn = connected_clients[ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}: 已标记为离线")
except Exception as e:
print(f"[{current_time}] 客户端{ip}: 离线处理失败 - {str(e)}")
finally:
connected_clients.pop(ip, None)
else:
print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 6. 客户端明文消息处理(关键修改:删除解密逻辑)--------------------------
async def send_heartbeat_ack(conn: ClientConnection):
"""发送加密的心跳确认(服务器→客户端:加密)"""
try:
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg)) # 保持加密
await conn.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认")
return True
except Exception as e:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认失败 - {str(e)}")
return False
async def handle_text_msg(conn: ClientConnection, text: str):
"""处理客户端明文文本消息(如心跳)- 关键修改无需解密直接解析JSON"""
try:
# 客户端发的是明文JSON直接解析删除原解密步骤
msg = json.loads(text)
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(conn) # 服务器回复仍加密
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知明文文本类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式明文文本")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 明文文本消息处理失败 - {str(e)}")
# -------------------------- 7. WebSocket路由与生命周期关键修改处理明文二进制图像--------------------------
ws_router = APIRouter()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期:启动/停止心跳任务"""
global heartbeat_task
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 心跳检查任务启动ID: {id(heartbeat_task)}")
yield
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
await heartbeat_task
print(f"[{get_current_time_str()}] 心跳检查任务已取消")
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket连接处理入口 - 关键修改:接收客户端明文二进制图像"""
load_model() # 加载检测模型(建议移到全局,避免重复加载)
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}: 连接已建立")
is_online_updated = False
try:
# 处理重复连接(关闭旧连接)
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="新连接建立")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
# 注册新连接
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer()
await new_conn.send_frame_permit() # 首次许可仍加密
# 标记客户端上线
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}")
# 消息循环:接收客户端明文消息(关键修改)
while True:
data = await websocket.receive()
if "text" in data:
# 处理客户端明文文本(如心跳:{"type":"heart",...}
await handle_text_msg(new_conn, data["text"])
elif "bytes" in data:
# 处理客户端明文二进制图像(直接入队,无需解密)
frame_data = data["bytes"]
try:
new_conn.frame_queue.put_nowait(frame_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像({len(frame_data)}字节)入队")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 帧队列已满,丢弃数据")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 明文图像处理失败 - {str(e)}")
except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally:
# 清理资源
if client_ip in connected_clients:
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")