优化代码风格

This commit is contained in:
ZZX9599
2025-09-15 18:51:18 +08:00
parent 3cb83b292e
commit 4549e67a68

247
ws/ws.py
View File

@ -2,38 +2,81 @@ import asyncio
import datetime
import json
import os
import base64
from contextlib import asynccontextmanager
from typing import Dict, Optional
import cv2
import numpy as np
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect, load_model
import cv2
import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
# -------------------------- 1. AES 加密解密工具(固定密钥)--------------------------
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa" # 约定密钥32字节
AES_BLOCK_SIZE = 16 # AES固定块大小
# 配置常量
def aes_encrypt(plaintext: str) -> dict:
"""AES-CBC加密返回{密文, IV, 算法标识}均Base64编码"""
try:
iv = os.urandom(AES_BLOCK_SIZE) # 随机IV16字节
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
# 明文填充+加密+Base64编码
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
iv_base64 = base64.b64encode(iv).decode("utf-8")
return {
"ciphertext": ciphertext,
"iv": iv_base64,
"algorithm": "AES-CBC"
}
except Exception as e:
raise Exception(f"AES加密失败: {str(e)}") from e
def aes_decrypt(encrypted_dict: dict) -> str:
"""AES-CBC解密输入加密字典返回原始文本"""
try:
# Base64解码密文和IV
ciphertext = base64.b64decode(encrypted_dict["ciphertext"])
iv = base64.b64decode(encrypted_dict["iv"])
# 解密+去除填充
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
decrypted = unpad(cipher.decrypt(ciphertext), AES_BLOCK_SIZE).decode("utf-8")
return decrypted
except Exception as e:
raise Exception(f"AES解密失败: {str(e)}") from e
# -------------------------- 2. 配置常量(保持原有)--------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# 工具函数: 获取格式化时间字符串
# -------------------------- 3. 工具函数(保持原有)--------------------------
def get_current_time_str() -> str:
"""获取格式化时间字符串"""
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str:
"""获取文件命名用时间字符串"""
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# 客户端连接封装
# -------------------------- 4. 客户端连接封装(新增消息加密)--------------------------
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip # 已初始化客户端IP用于传递给detect
self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None
@ -53,96 +96,89 @@ class ClientConnection:
return self.consumer_task
async def send_frame_permit(self):
"""发送帧发送许可信号"""
"""发送加密的帧许可信号"""
try:
# 1. 构建原始消息
frame_permit_msg = {
"type": "frame",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
}
await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号")
# 2. AES加密消息
encrypted_msg = aes_encrypt(json.dumps(frame_permit_msg))
# 3. 发送加密消息
await self.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送加密帧许可")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可加密/发送失败 - {str(e)}")
async def consume_frames(self) -> None:
"""消费队列中的帧并处理"""
try:
while True:
# 取出帧并立即发送下一帧许可
frame_data = await self.frame_queue.get()
await self.send_frame_permit()
await self.send_frame_permit() # 发送下一帧许可
try:
await self.process_frame(frame_data)
finally:
self.frame_queue.task_done()
except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据核心修改detect函数传入 client_ip + img 双参数"""
"""处理单帧图像(含加密危险通知"""
# 二进制转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像数据")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像")
return
try:
# -------------------------- 核心修改按要求传入参数1.client_ip 2.img --------------------------
# detect函数参数顺序第一个为client_ip第二个为图像数据img
# 保持返回值解包(是否违规, 结果数据, 检测器类型)不变
# 调用检测函数(client_ip + img 双参数
has_violation, data, detector_type = await asyncio.to_thread(
detect, # 调用检测函数
self.client_ip, # 第一个参数客户端IP新增按需求顺序
img # 第二个参数:图像数据(原参数,调整顺序)
detect, self.client_ip, img
)
# -------------------------------------------------------------------------------------
print(
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - 违规: {has_violation}, 类型: {detector_type}")
# 打印检测结果包含客户端IP与传入参数对应
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}")
# 处理违规逻辑逻辑不变基于detect返回结果执行
# 处理违规逻辑(发送加密危险通知
if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
f"类型: {detector_type}, 详情: {data}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - {data}")
# 违规次数+1
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数+1")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 发送危险通知
# 1. 构建原始危险通知
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip,
"detail": data
}
await self.websocket.send_json(danger_msg)
else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
# 2. AES加密通知
encrypted_danger_msg = aes_encrypt(json.dumps(danger_msg))
# 3. 发送加密通知
await self.websocket.send_json(encrypted_danger_msg)
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
# 全局状态管理
# -------------------------- 5. 全局状态与心跳管理(保持原有)--------------------------
connected_clients: Dict[str, ClientConnection] = {}
heartbeat_task: Optional[asyncio.Task] = None
# 心跳检查任务
async def heartbeat_checker():
"""全局心跳检查任务"""
while True:
current_time = get_current_time_str()
# 筛选超时客户端
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips:
@ -150,18 +186,17 @@ async def heartbeat_checker():
for ip in timeout_ips:
try:
conn = connected_clients[ip]
# 取消消费任务+关闭连接
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 标记离线
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}: 已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}: 离线状态更新失败 - {str(e)}")
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}: 已标记为离线")
except Exception as e:
print(f"[{current_time}] 客户端{ip}: 离线处理失败 - {str(e)}")
finally:
connected_clients.pop(ip, None)
else:
@ -170,87 +205,108 @@ async def heartbeat_checker():
await asyncio.sleep(HEARTBEAT_INTERVAL)
# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
global heartbeat_task
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 全局心跳检查任务启动任务ID: {id(heartbeat_task)}")
yield
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
try:
await heartbeat_task
print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
except asyncio.CancelledError:
pass
# 消息处理工具函数
# -------------------------- 6. 消息处理工具(新增消息解密)--------------------------
async def send_heartbeat_ack(conn: ClientConnection):
"""发送加密的心跳确认"""
try:
# 1. 构建原始心跳确认
heartbeat_ack_msg = {
"type": "heart",
"timestamp": get_current_time_str(),
"client_ip": conn.client_ip
}
await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送心跳确认")
# 2. AES加密
encrypted_msg = aes_encrypt(json.dumps(heartbeat_ack_msg))
# 3. 发送
await conn.websocket.send_json(encrypted_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送加密心跳确认")
return True
except Exception as e:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认发送失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认失败 - {str(e)}")
return False
async def handle_text_msg(conn: ClientConnection, text: str):
"""处理加密的文本消息(如心跳)"""
try:
msg = json.loads(text)
# 1. 解析加密字典
encrypted_dict = json.loads(text)
# 2. AES解密
decrypted_text = aes_decrypt(encrypted_dict)
# 3. 解析业务消息
msg = json.loads(decrypted_text)
if msg.get("type") == "heart":
conn.update_heartbeat()
await send_heartbeat_ack(conn)
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本消息类型({msg.get('type')}")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON文本消息")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON格式")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 文本消息解密失败 - {str(e)}")
async def handle_binary_msg(conn: ClientConnection, data: bytes):
async def handle_binary_msg(conn: ClientConnection, data: str):
"""处理加密的图像消息客户端需先转Base64+加密)"""
try:
conn.frame_queue.put_nowait(data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像数据({len(data)}字节)已加入队列")
# 1. 解密得到Base64编码的图像
encrypted_dict = json.loads(data)
decrypted_base64 = aes_decrypt(encrypted_dict)
# 2. Base64解码为二进制图像
frame_data = base64.b64decode(decrypted_base64)
# 3. 加入帧队列
conn.frame_queue.put_nowait(frame_data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 解密后图像({len(frame_data)}字节)入队")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满丢弃当前图像数据")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满丢弃数据")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像消息解密失败 - {str(e)}")
# WebSocket路由配置
# -------------------------- 7. WebSocket路由与生命周期保持原有结构--------------------------
ws_router = APIRouter()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期:启动/停止心跳任务"""
global heartbeat_task
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 心跳检查任务启动ID: {id(heartbeat_task)}")
yield
# 关闭时清理
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
await heartbeat_task
print(f"[{get_current_time_str()}] 心跳检查任务已取消")
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
load_model() # 加载检测模型(仅在连接建立时加载一次,避免重复加载)
"""WebSocket连接处理入口"""
load_model() # 加载检测模型(仅一次)
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}: WebSocket连接已建立")
print(f"[{current_time}] 客户端{client_ip}: 连接已建立")
is_online_updated = False
try:
# 处理重复连接(同一IP断开旧连接)
# 处理重复连接(关闭旧连接)
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
await old_conn.websocket.close(code=1008, reason="新连接建立")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
# 注册新连接绑定client_ip和WebSocket
# 注册新连接
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer() # 启动帧消费任务
new_conn.start_consumer() # 启动帧消费
await new_conn.send_frame_permit() # 发送首次帧许可
# 标记客户端上线
@ -258,41 +314,42 @@ async def websocket_endpoint(websocket: WebSocket):
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线并记录操作")
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}: 连接注册成功在线数: {len(connected_clients)}")
print(f"[{current_time}] 客户端{client_ip}: 连接注册成功在线数: {len(connected_clients)}")
# 消息循环(持续接收客户端消息)
# 消息循环(接收客户端消息)
while True:
data = await websocket.receive()
if "text" in data:
# 处理加密文本消息(心跳、客户端指令)
await handle_text_msg(new_conn, data["text"])
elif "bytes" in data:
await handle_binary_msg(new_conn, data["bytes"])
# 兼容客户端发送二进制先转Base64再处理
base64_data = base64.b64encode(data["bytes"]).decode("utf-8")
await handle_binary_msg(new_conn, base64_data)
except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally:
# 清理资源(断开后标记离线+删除连接
# 清理资源(断开后处理
if client_ip in connected_clients:
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 仅当上线状态更新成功时,才执行离线更新
# 仅上线成功的客户端,才标记离线
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后标记为离线")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理在线数: {len(connected_clients)}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理在线数: {len(connected_clients)}")