Files
video/ws/ws.py
2025-09-15 18:51:18 +08:00

356 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import datetime
import json
import os
import base64
from contextlib import asynccontextmanager
from typing import Dict, Optional
import cv2
import numpy as np
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect, load_model
# -------------------------- 1. AES 加密解密工具(固定密钥)--------------------------
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥32字节
AES_BLOCK_SIZE = 16 # AES固定块大小
def aes_encrypt(plaintext: str) -> dict:
"""AES-CBC加密返回{密文, IV, 算法标识}均Base64编码"""
try:
iv = os.urandom(AES_BLOCK_SIZE) # 随机IV16字节
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
# 明文填充+加密+Base64编码
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
iv_base64 = base64.b64encode(iv).decode("utf-8")
return {
"ciphertext": ciphertext,
"iv": iv_base64,
"algorithm": "AES-CBC"
}
except Exception as e:
raise Exception(f"AES加密失败: {str(e)}") from e
def aes_decrypt(encrypted_dict: dict) -> str:
"""AES-CBC解密输入加密字典返回原始文本"""
try:
# Base64解码密文和IV
ciphertext = base64.b64decode(encrypted_dict["ciphertext"])
iv = base64.b64decode(encrypted_dict["iv"])
# 解密+去除填充
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
decrypted = unpad(cipher.decrypt(ciphertext), AES_BLOCK_SIZE).decode("utf-8")
return decrypted
except Exception as e:
raise Exception(f"AES解密失败: {str(e)}") from e
# -------------------------- 2. 配置常量(保持原有)--------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# -------------------------- 3. 工具函数(保持原有)--------------------------
def get_current_time_str() -> str:
"""获取格式化时间字符串"""
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str:
"""获取文件命名用时间字符串"""
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# -------------------------- 4. 客户端连接封装(新增消息加密)--------------------------
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None
def update_heartbeat(self):
"""更新心跳时间"""
self.last_heartbeat = datetime.datetime.now()
def is_alive(self) -> bool:
"""判断客户端是否存活"""
timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout_seconds < HEARTBEAT_TIMEOUT
def start_consumer(self):
"""启动帧消费任务"""
self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task
async def send_frame_permit(self):
"""发送加密的帧许可信号"""
try:
# 1. 构建原始消息
frame_permit_msg = {
"type": "frame",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
}
# 2. AES加密消息
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg))
# 3. 发送加密消息
await self.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}")
async def consume_frames(self) -> None:
"""消费队列中的帧并处理"""
try:
while True:
frame_data = await self.frame_queue.get()
await self.send_frame_permit() # 发送下一帧许可
try:
await self.process_frame(frame_data)
finally:
self.frame_queue.task_done()
except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像(含加密危险通知)"""
# 二进制转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像")
return
try:
# 调用检测函数client_ip + img 双参数)
has_violation, data, detector_type = await asyncio.to_thread(
detect, self.client_ip, img
)
print(
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}")
# 处理违规逻辑(发送加密危险通知)
if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}")
# 违规次数+1
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 1. 构建原始危险通知
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip,
"detail": data
}
# 2. AES加密通知
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg))
# 3. 发送加密通知
await self.websocket.send_json(encrypted_danger_msg)
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
# -------------------------- 5. 全局状态与心跳管理(保持原有)--------------------------
connected_clients: Dict[str, ClientConnection] = {}
heartbeat_task: Optional[asyncio.Task] = None
async def heartbeat_checker():
"""全局心跳检查任务"""
while True:
current_time = get_current_time_str()
# 筛选超时客户端
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips:
print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时IP: {timeout_ips}")
for ip in timeout_ips:
try:
conn = connected_clients[ip]
# 取消消费任务+关闭连接
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}: 已标记为离线")
except Exception as e:
print(f"[{current_time}] 客户端{ip}: 离线处理失败 - {str(e)}")
finally:
connected_clients.pop(ip, None)
else:
print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 6. 消息处理工具(新增消息解密)--------------------------
async def send_heartbeat_ack(conn: ClientConnection):
"""发送加密的心跳确认"""
try:
# 1. 构建原始心跳确认
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
# 2. AES加密
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg))
# 3. 发送
await conn.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认")
return True
except Exception as e:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认失败 - {str(e)}")
return False
async def handle_text_msg(conn: ClientConnection, text: str):
"""处理加密的文本消息(如心跳)"""
try:
# 1. 解析加密字典
encrypted_dict = json.loads(text)
# 2. AES解密
decrypted_text = aes_decrypt(encrypted_dict)
# 3. 解析业务消息
msg = json.loads(decrypted_text)
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(conn)
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 文本消息解密失败 - {str(e)}")
async def handle_binary_msg(conn: ClientConnection, data: str):
"""处理加密的图像消息客户端需先转Base64+加密)"""
try:
# 1. 解密得到Base64编码的图像
encrypted_dict = json.loads(data)
decrypted_base64 = aes_decrypt(encrypted_dict)
# 2. Base64解码为二进制图像
frame_data = base64.b64decode(decrypted_base64)
# 3. 加入帧队列
conn.frame_queue.put_nowait(frame_data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 解密后图像({len(frame_data)}字节)入队")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满,丢弃数据")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像消息解密失败 - {str(e)}")
# -------------------------- 7. WebSocket路由与生命周期保持原有结构--------------------------
ws_router = APIRouter()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期:启动/停止心跳任务"""
global heartbeat_task
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 心跳检查任务启动ID: {id(heartbeat_task)}")
yield
# 关闭时清理
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
await heartbeat_task
print(f"[{get_current_time_str()}] 心跳检查任务已取消")
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket连接处理入口"""
load_model() # 加载检测模型(仅一次)
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}: 连接已建立")
is_online_updated = False
try:
# 处理重复连接(关闭旧连接)
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="新连接建立")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
# 注册新连接
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer() # 启动帧消费
await new_conn.send_frame_permit() # 发送首次帧许可
# 标记客户端上线
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}: 连接注册成功,在线数: {len(connected_clients)}")
# 消息循环(接收客户端消息)
while True:
data = await websocket.receive()
if "text" in data:
# 处理加密文本消息(心跳、客户端指令)
await handle_text_msg(new_conn, data["text"])
elif "bytes" in data:
# 兼容客户端发送二进制先转Base64再处理
base64_data = base64.b64encode(data["bytes"]).decode("utf-8")
await handle_binary_msg(new_conn, base64_data)
except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally:
# 清理资源(断开后处理)
if client_ip in connected_clients:
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 仅上线成功的客户端,才标记离线
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理,在线数: {len(connected_clients)}")