This commit is contained in:
ZZX9599
2025-09-03 17:08:28 +08:00
parent 9d940e7fd2
commit d9198229aa

147
ws/ws.py
View File

@ -3,16 +3,16 @@ import datetime
import json import json
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Optional from typing import Dict, Optional, AsyncGenerator
import cv2 import cv2
import numpy as np import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from ocr.model_violation_detector import MultiModelViolationDetector from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整) # 配置文件相对路径(根据实际目录结构调整)
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" # 关键修正从core目录向上一级找ocr文件夹 YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt" FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces" KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
@ -26,21 +26,24 @@ detector = MultiModelViolationDetector(
ocr_confidence_threshold=0.5 ocr_confidence_threshold=0.5
) )
# -------------------------- 配置常量(简化硬编码) -------------------------- # -------------------------- 配置常量 --------------------------
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径 WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# -------------------------- 核心数据结构与全局变量 -------------------------- # -------------------------- 核心数据结构与全局变量 --------------------------
ws_router = APIRouter() ws_router = APIRouter()
# 客户端连接封装(仅保留核心属性和方法 # 客户端连接封装(包含帧队列
class ClientConnection: class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str): def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket self.websocket = websocket
self.client_ip = client_ip self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 帧队列长度为1
self.consumer_task: Optional[asyncio.Task] = None # 消费者任务
# 更新心跳时间 # 更新心跳时间
def update_heartbeat(self): def update_heartbeat(self):
@ -51,6 +54,69 @@ class ClientConnection:
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT return timeout < HEARTBEAT_TIMEOUT
# 启动帧消费任务
def start_consumer(self):
self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task
# 帧消费协程
async def consume_frames(self) -> None:
"""从队列中获取帧并进行处理"""
try:
while True:
# 从队列获取帧数据
frame_data = await self.frame_queue.get()
try:
# 处理帧数据
await self.process_frame(frame_data)
finally:
# 标记任务完成
self.frame_queue.task_done()
except asyncio.CancelledError:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧消费任务已取消")
except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧处理错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据"""
# 将二进制数据转换为NumPy数组uint8类型
nparr = np.frombuffer(frame_data, np.uint8)
# 解码为图像返回与cv2.imread相同的格式BGR通道的ndarray
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
# 确保images文件夹存在
if not os.path.exists('images'):
os.makedirs('images')
# 生成唯一的文件名包含时间戳和客户端IP避免文件名冲突
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"images/{self.client_ip.replace('.', '_')}_{timestamp}.jpg"
try:
# 保存图像到本地
cv2.imwrite(filename, img)
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}")
# 进行检测
if img is not None:
has_violation, violation_type, details = detector.detect_violations(img)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
# 可以在这里添加发送检测结果回客户端的逻辑
await self.websocket.send_json({
"type": "detection_result",
"has_violation": has_violation,
"violation_type": violation_type,
"details": details,
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})
else:
print("未检测到任何违规内容")
else:
print(f"无法解析图像数据")
except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像处理错误 - {str(e)}")
# 全局连接管理IP -> 连接实例) # 全局连接管理IP -> 连接实例)
connected_clients: Dict[str, ClientConnection] = {} connected_clients: Dict[str, ClientConnection] = {}
@ -58,7 +124,7 @@ connected_clients: Dict[str, ClientConnection] = {}
heartbeat_task: Optional[asyncio.Task] = None heartbeat_task: Optional[asyncio.Task] = None
# -------------------------- 心跳检查逻辑(精简日志) -------------------------- # -------------------------- 心跳检查逻辑 --------------------------
async def heartbeat_checker(): async def heartbeat_checker():
while True: while True:
now = datetime.datetime.now() now = datetime.datetime.now()
@ -70,6 +136,9 @@ async def heartbeat_checker():
print(f"[{now:%H:%M:%S}] 心跳检查:{len(timeout_ips)}个客户端超时({timeout_ips}") print(f"[{now:%H:%M:%S}] 心跳检查:{len(timeout_ips)}个客户端超时({timeout_ips}")
for ip in timeout_ips: for ip in timeout_ips:
try: try:
# 取消消费者任务
if connected_clients[ip].consumer_task and not connected_clients[ip].consumer_task.done():
connected_clients[ip].consumer_task.cancel()
await connected_clients[ip].websocket.close(code=1008, reason="心跳超时") await connected_clients[ip].websocket.close(code=1008, reason="心跳超时")
finally: finally:
connected_clients.pop(ip, None) connected_clients.pop(ip, None)
@ -80,7 +149,7 @@ async def heartbeat_checker():
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 应用生命周期(简化异常处理) -------------------------- # -------------------------- 应用生命周期 --------------------------
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global heartbeat_task global heartbeat_task
@ -98,7 +167,7 @@ async def lifespan(app: FastAPI):
pass pass
# -------------------------- 消息处理(合并冗余逻辑) -------------------------- # -------------------------- 消息处理 --------------------------
async def send_heartbeat_ack(client_ip: str): async def send_heartbeat_ack(client_ip: str):
"""回复心跳确认""" """回复心跳确认"""
if client_ip not in connected_clients: if client_ip not in connected_clients:
@ -130,33 +199,28 @@ async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection):
async def handle_binary_msg(client_ip: str, data: bytes): async def handle_binary_msg(client_ip: str, data: bytes):
"""处理二进制消息(保留扩展入口""" """处理二进制消息(使用队列控制帧处理"""
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:收到{len(data)}字节二进制数据") if client_ip not in connected_clients:
# 将二进制数据转换为NumPy数组uint8类型 print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接不存在,丢弃{len(data)}字节数据")
nparr = np.frombuffer(data, np.uint8) return
# 解码为图像返回与cv2.imread相同的格式BGR通道的ndarray
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
#转存到本地images文件夹下
# 确保images文件夹存在
if not os.path.exists('images'):
os.makedirs('images')
# 生成唯一的文件名包含时间戳和客户端IP避免文件名冲突 conn = connected_clients[client_ip]
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"images/{client_ip.replace('.', '_')}_{timestamp}.jpg" # 检查队列是否已满
# 保存图像到本地 if conn.frame_queue.full():
# 队列已满,丢弃当前帧
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列已满,丢弃{len(data)}字节数据")
return
# 队列未满,添加帧到队列
try:
# 非阻塞添加(因为已检查队列未满,所以不会阻塞)
conn.frame_queue.put_nowait(data)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:已添加{len(data)}字节数据到队列")
except asyncio.QueueFull:
# 理论上不会走到这里,因为上面已检查
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列突然满了,丢弃{len(data)}字节数据")
cv2.imwrite(filename, img)
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}")
# 进行检测
if img is not None:
has_violation, violation_type, details = detector.detect_violations(img)
if has_violation:
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
else:
print("未检测到任何违规内容")
else:
print(f"无法读取测试图像:")
# -------------------------- WebSocket核心端点 -------------------------- # -------------------------- WebSocket核心端点 --------------------------
@ws_router.websocket(WS_ENDPOINT) @ws_router.websocket(WS_ENDPOINT)
@ -167,9 +231,13 @@ async def websocket_endpoint(websocket: WebSocket):
now = datetime.datetime.now() now = datetime.datetime.now()
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:连接成功") print(f"[{now:%H:%M:%S}] 客户端{client_ip}:连接成功")
consumer_task = None
try: try:
# 处理重复连接(关闭旧连接) # 处理重复连接(关闭旧连接)
if client_ip in connected_clients: if client_ip in connected_clients:
# 取消旧连接的消费者任务
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done():
connected_clients[client_ip].consumer_task.cancel()
await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接") await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接")
connected_clients.pop(client_ip) connected_clients.pop(client_ip)
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:关闭旧连接") print(f"[{now:%H:%M:%S}] 客户端{client_ip}:关闭旧连接")
@ -177,7 +245,10 @@ async def websocket_endpoint(websocket: WebSocket):
# 注册新连接 # 注册新连接
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn connected_clients[client_ip] = new_conn
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,当前在线{len(connected_clients)}")
# 启动帧消费任务
consumer_task = new_conn.start_consumer()
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,已启动帧消费任务,当前在线{len(connected_clients)}")
# 循环接收消息 # 循环接收消息
while True: while True:
@ -193,6 +264,10 @@ async def websocket_endpoint(websocket: WebSocket):
except Exception as e: except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接异常({str(e)[:50]}") print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接异常({str(e)[:50]}")
finally: finally:
# 清理连接 # 清理连接和任务
connected_clients.pop(client_ip, None) if client_ip in connected_clients:
# 取消消费者任务
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done():
connected_clients[client_ip].consumer_task.cancel()
connected_clients.pop(client_ip, None)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接已清理,当前在线{len(connected_clients)}") print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接已清理,当前在线{len(connected_clients)}")