Files
video/ws/ws.py
2025-09-03 17:02:22 +08:00

199 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import datetime
import json
import os
from contextlib import asynccontextmanager
from typing import Dict, Optional
import cv2
import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" # 关键修正从core目录向上一级找ocr文件夹
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
# 创建检测器实例
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# -------------------------- 配置常量(简化硬编码) --------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径
# -------------------------- 核心数据结构与全局变量 --------------------------
ws_router = APIRouter()
# 客户端连接封装(仅保留核心属性和方法)
class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket
self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now()
# 更新心跳时间
def update_heartbeat(self):
self.last_heartbeat = datetime.datetime.now()
# 检查是否存活超时返回False
def is_alive(self) -> bool:
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT
# 全局连接管理IP -> 连接实例)
connected_clients: Dict[str, ClientConnection] = {}
# 心跳任务(全局引用,用于关闭时清理)
heartbeat_task: Optional[asyncio.Task] = None
# -------------------------- 心跳检查逻辑(精简日志) --------------------------
async def heartbeat_checker():
while True:
now = datetime.datetime.now()
# 1. 筛选超时客户端(避免遍历中修改字典)
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
# 2. 处理超时连接(关闭+移除)
if timeout_ips:
print(f"[{now:%H:%M:%S}] 心跳检查:{len(timeout_ips)}个客户端超时({timeout_ips}")
for ip in timeout_ips:
try:
await connected_clients[ip].websocket.close(code=1008, reason="心跳超时")
finally:
connected_clients.pop(ip, None)
else:
print(f"[{now:%H:%M:%S}] 心跳检查:{len(connected_clients)}个客户端在线,无超时")
# 3. 等待下一轮检查
await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 应用生命周期(简化异常处理) --------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
global heartbeat_task
# 启动心跳任务
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务启动ID{id(heartbeat_task)}")
yield
# 关闭时取消心跳任务
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
try:
await heartbeat_task
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务已取消")
except asyncio.CancelledError:
pass
# -------------------------- 消息处理(合并冗余逻辑) --------------------------
async def send_heartbeat_ack(client_ip: str):
"""回复心跳确认"""
if client_ip not in connected_clients:
return False
try:
ack = {
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"type": "heartbeat"
}
await connected_clients[client_ip].websocket.send_json(ack)
return True
except Exception:
connected_clients.pop(client_ip, None)
return False
async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection):
"""处理文本消息(核心:心跳+JSON解析"""
try:
msg = json.loads(text)
# 仅处理心跳类型消息
if msg.get("type") == "heartbeat":
conn.update_heartbeat()
await send_heartbeat_ack(client_ip)
else:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到消息:{msg}")
except json.JSONDecodeError:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}无效JSON消息")
async def handle_binary_msg(client_ip: str, data: bytes):
"""处理二进制消息(保留扩展入口)"""
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到{len(data)}字节二进制数据")
# 将二进制数据转换为NumPy数组uint8类型
nparr = np.frombuffer(data, np.uint8)
# 解码为图像返回与cv2.imread相同的格式BGR通道的ndarray
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
#转存到本地images文件夹下
# 确保images文件夹存在
if not os.path.exists('images'):
os.makedirs('images')
# 生成唯一的文件名包含时间戳和客户端IP避免文件名冲突
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"images/{client_ip.replace('.', '_')}_{timestamp}.jpg"
# 保存图像到本地
cv2.imwrite(filename, img)
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}")
# 进行检测
if img is not None:
has_violation, violation_type, details = detector.detect_violations(img)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
else:
print("未检测到任何违规内容")
else:
print(f"无法读取测试图像:")
# -------------------------- WebSocket核心端点 --------------------------
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
# 接受连接 + 获取客户端IP
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown"
now = datetime.datetime.now()
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:连接成功")
try:
# 处理重复连接(关闭旧连接)
if client_ip in connected_clients:
await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接")
connected_clients.pop(client_ip)
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:关闭旧连接")
# 注册新连接
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,当前在线{len(connected_clients)}")
# 循环接收消息
while True:
data = await websocket.receive()
if "text" in data:
await handle_text_msg(client_ip, data["text"], new_conn)
elif "bytes" in data:
await handle_binary_msg(client_ip, data["bytes"])
# 异常处理(断开/错误)
except WebSocketDisconnect as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:主动断开(代码:{e.code}")
except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接异常({str(e)[:50]}")
finally:
# 清理连接
connected_clients.pop(client_ip, None)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接已清理,当前在线{len(connected_clients)}")