This commit is contained in:
ZZX9599
2025-09-04 17:08:25 +08:00
parent b5d870a19c
commit 08f8a0e44e

211
ws/ws.py
View File

@ -11,6 +11,8 @@ from schema.device_action_schema import DeviceActionCreate
import cv2 import cv2
import numpy as np import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from queue import Queue
from threading import Lock
from ocr.model_violation_detector import MultiModelViolationDetector from ocr.model_violation_detector import MultiModelViolationDetector
@ -20,14 +22,8 @@ FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces" KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
# 创建检测器实例 # 模型池配置根据GPU显存调整每个模型约占1G显存
detector = MultiModelViolationDetector( MODEL_POOL_SIZE = 3 # 最大并发客户端数
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
# 配置常量 # 配置常量
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
@ -36,7 +32,39 @@ WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制 FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# 工具函数:获取格式化时间字符串(统一时间戳格式) # 模型池实现 - 提前初始化固定数量的模型实例
class ModelPool:
def __init__(self, pool_size: int = MODEL_POOL_SIZE):
self.pool = Queue(maxsize=pool_size)
self.lock = Lock()
# 提前初始化模型实例(显存会在此阶段预分配)
for i in range(pool_size):
detector = MultiModelViolationDetector(
forbidden_words_path=FORBIDDEN_WORDS_PATH,
ocr_config_path=OCR_CONFIG_PATH,
yolo_model_path=YOLO_MODEL_PATH,
known_faces_dir=KNOWN_FACES_DIR,
ocr_confidence_threshold=0.5
)
self.pool.put(detector)
print(f"[{get_current_time_str()}] 模型池初始化:第{i + 1}/{pool_size}个模型加载完成")
def get_model(self) -> MultiModelViolationDetector:
"""从池子里获取模型(阻塞直到有可用实例)"""
with self.lock:
return self.pool.get()
def return_model(self, detector: MultiModelViolationDetector):
"""将模型归还给池子"""
with self.lock:
self.pool.put(detector)
# 初始化模型池(程序启动时加载所有模型,显存会一次性占用 MODEL_POOL_SIZE * 单模型显存)
model_pool = ModelPool(pool_size=MODEL_POOL_SIZE)
# 工具函数:获取格式化时间字符串
def get_current_time_str() -> str: def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -54,12 +82,16 @@ class ClientConnection:
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None self.consumer_task: Optional[asyncio.Task] = None
# 从模型池获取专属模型(每个客户端独立占用一个模型实例)
self.detector = model_pool.get_model()
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已获取模型池中的模型实例(显存独立)")
def update_heartbeat(self): def update_heartbeat(self):
"""更新心跳时间(客户端发送心跳时调用)""" """更新心跳时间"""
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
def is_alive(self) -> bool: def is_alive(self) -> bool:
"""判断客户端是否存活(心跳超时检查)""" """判断客户端是否存活"""
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT return timeout < HEARTBEAT_TIMEOUT
@ -68,11 +100,13 @@ class ClientConnection:
self.consumer_task = asyncio.create_task(self.consume_frames()) self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task return self.consumer_task
def release_model(self):
"""客户端断开时归还模型到池"""
model_pool.return_model(self.detector)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:模型已归还至模型池(显存可复用)")
async def send_frame_permit(self): async def send_frame_permit(self):
""" """发送帧发送许可信号"""
发送「帧发送许可信号」
通知客户端可发送下一帧图像
"""
try: try:
frame_permit_msg = { frame_permit_msg = {
"type": "frame", "type": "frame",
@ -80,26 +114,24 @@ class ClientConnection:
"client_ip": self.client_ip "client_ip": self.client_ip
} }
await self.websocket.send_json(frame_permit_msg) await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}")
async def consume_frames(self) -> None: async def consume_frames(self) -> None:
"""消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理帧""" """消费队列中的帧并处理(并行执行核心)"""
try: try:
while True: while True:
# 1. 从队列取出帧(阻塞直到有帧可用) # 1. 从队列取出帧
frame_data = await self.frame_queue.get() frame_data = await self.frame_queue.get()
# -------------------------- 核心修改:取出帧后立即发送下一帧许可 -------------------------- # 2. 立即发送下一帧许可
await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成 await self.send_frame_permit()
# -----------------------------------------------------------------------------------------
try: try:
# 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧 # 3. 并行处理帧用线程池执行AI检测真正并发
await self.process_frame(frame_data) await self.process_frame(frame_data)
finally: finally:
# 3. 标记帧任务完成(无论处理成功/失败,都需清理队列)
self.frame_queue.task_done() self.frame_queue.task_done()
except asyncio.CancelledError: except asyncio.CancelledError:
@ -108,7 +140,7 @@ class ClientConnection:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费逻辑错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None: async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法""" """处理单帧图像数据(使用客户端专属模型"""
# 二进制数据转OpenCV图像 # 二进制数据转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8) nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
@ -119,14 +151,18 @@ class ClientConnection:
# 确保图像保存目录存在 # 确保图像保存目录存在
os.makedirs('images', exist_ok=True) os.makedirs('images', exist_ok=True)
# 保存图像按IP+时间戳命名,避免冲突) # 保存图像
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg" filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
try: try:
cv2.imwrite(filename, img) cv2.imwrite(filename, img)
print(f"[{get_current_time_str()}] 图像已保存至:{filename}") print(f"[{get_current_time_str()}] 图像已保存至:{filename}")
# 执行违规检测 # 关键修改:使用客户端专属模型 + 线程池并行执行AI检测
has_violation, violation_type, details = detector.detect_violations(img) has_violation, violation_type, details = await asyncio.to_thread(
self.detector.detect_violations, # 客户端独立模型
img # 输入图像
)
if has_violation: if has_violation:
print( print(
f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {violation_type}, 详情: {details}") f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {violation_type}, 详情: {details}")
@ -138,7 +174,7 @@ class ClientConnection:
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
# 发送危险通知 # 发送危险通知
danger_msg = { danger_msg = {
"type": "danger", "type": "danger",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
@ -153,36 +189,48 @@ class ClientConnection:
# 全局状态管理 # 全局状态管理
connected_clients: Dict[str, ClientConnection] = {} connected_clients: Dict[str, ClientConnection] = {}
client_lock = asyncio.Lock() # 保护connected_clients的锁
heartbeat_task: Optional[asyncio.Task] = None heartbeat_task: Optional[asyncio.Task] = None
# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法) # 心跳检查
async def heartbeat_checker(): async def heartbeat_checker():
while True: while True:
current_time = get_current_time_str() current_time = get_current_time_str()
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] # 加锁保护字典遍历
async with client_lock:
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips: if timeout_ips:
print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时IP{timeout_ips}") print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时IP{timeout_ips}")
for ip in timeout_ips: for ip in timeout_ips:
try: try:
conn = connected_clients[ip] async with client_lock:
if conn.consumer_task and not conn.consumer_task.done(): conn = connected_clients.get(ip)
conn.consumer_task.cancel() if not conn:
await conn.websocket.close(code=1008, reason="心跳超时") continue
# 超时设为离线并记录 if conn.consumer_task and not conn.consumer_task.done():
try: conn.consumer_task.cancel()
await asyncio.to_thread(update_online_status_by_ip, ip, 0) await conn.websocket.close(code=1008, reason="心跳超时")
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data) # 归还模型
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作") conn.release_model()
except Exception as e:
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}") # 超时设为离线并记录
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
finally: finally:
connected_clients.pop(ip, None) async with client_lock:
connected_clients.pop(ip, None)
else: else:
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线") async with client_lock:
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
@ -215,7 +263,8 @@ async def send_heartbeat_ack(conn: ClientConnection):
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
return True return True
except Exception as e: except Exception as e:
connected_clients.pop(conn.client_ip, None) async with client_lock:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
return False return False
@ -252,22 +301,26 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{current_time}] 客户端{client_ip}WebSocket连接已建立") print(f"[{current_time}] 客户端{client_ip}WebSocket连接已建立")
is_online_updated = False is_online_updated = False
new_conn = None
try: try:
# 处理重复连接 # 处理重复连接
if client_ip in connected_clients: async with client_lock:
old_conn = connected_clients[client_ip] if client_ip in connected_clients:
if old_conn.consumer_task and not old_conn.consumer_task.done(): old_conn = connected_clients[client_ip]
old_conn.consumer_task.cancel() if old_conn.consumer_task and not old_conn.consumer_task.done():
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立") old_conn.consumer_task.cancel()
connected_clients.pop(client_ip) await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接") old_conn.release_model() # 归还旧连接的模型
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接并回收模型")
# 注册新连接 # 注册新连接
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn async with client_lock:
connected_clients[client_ip] = new_conn
new_conn.start_consumer() new_conn.start_consumer()
# 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发) # 初始许可:连接建立后立即发一次
await new_conn.send_frame_permit() await new_conn.send_frame_permit()
# 标记上线并记录 # 标记上线并记录
@ -280,7 +333,8 @@ async def websocket_endpoint(websocket: WebSocket):
except Exception as e: except Exception as e:
print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}") print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}") async with client_lock:
print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}")
# 消息循环 # 消息循环
while True: while True:
@ -296,20 +350,37 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}") print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
finally: finally:
# 清理资源并标记离线 # 清理资源并标记离线
if client_ip in connected_clients: if new_conn and client_ip in connected_clients:
conn = connected_clients[client_ip] async with client_lock:
if conn.consumer_task and not conn.consumer_task.done(): conn = connected_clients.get(client_ip)
conn.consumer_task.cancel() if conn:
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 主动/异常断开时标记离线 # 归还模型到模型池
if is_online_updated: conn.release_model()
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None) # 主动/异常断开时标记离线
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}") if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
async with client_lock:
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}")
# 创建FastAPI应用
app = FastAPI(lifespan=lifespan)
app.include_router(ws_router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)