修改WS兼容检测的Future对象
This commit is contained in:
59
ws/ws.py
59
ws/ws.py
@ -7,12 +7,12 @@ from typing import Dict, Optional, AsyncGenerator
|
||||
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
|
||||
from service.device_action_service import add_device_action
|
||||
from schema.device_action_schema import DeviceActionCreate
|
||||
from core.all import detect
|
||||
# 【修改1:导入detect和TIMEOUT(用于检测超时控制)】
|
||||
from core.all import detect, load_model, TIMEOUT
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
|
||||
from core.all import load_model
|
||||
|
||||
# 配置常量
|
||||
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
|
||||
@ -93,7 +93,7 @@ class ClientConnection:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
|
||||
|
||||
async def process_frame(self, frame_data: bytes) -> None:
|
||||
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)"""
|
||||
"""处理单帧图像数据(【核心修改:等待检测结果+修正解包】)"""
|
||||
# 二进制数据转OpenCV图像
|
||||
nparr = np.frombuffer(frame_data, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
@ -109,13 +109,36 @@ class ClientConnection:
|
||||
try:
|
||||
cv2.imwrite(filename, img)
|
||||
print(f"[{get_current_time_str()}] 图像已保存至: {filename}")
|
||||
has_violation, data, type = detect(img)
|
||||
print(has_violation)
|
||||
print(type)
|
||||
print(data)
|
||||
|
||||
# -------------------------- 【核心修改1:提交检测任务并等待结果】 --------------------------
|
||||
# 1. 提交检测任务获取Future对象(非阻塞)
|
||||
detection_future = detect(img)
|
||||
# 2. 用asyncio.to_thread等待Future结果(避免阻塞asyncio事件循环),设置超时
|
||||
try:
|
||||
# 解包4元素结果:(是否违规, 结果数据, 检测器类型, 任务ID)
|
||||
has_violation, data, detector_type, task_id = await asyncio.to_thread(
|
||||
detection_future.result, # 调用Future的result()获取实际结果
|
||||
timeout=TIMEOUT # 超时控制(与all.py配置一致)
|
||||
)
|
||||
except TimeoutError:
|
||||
# 处理检测超时场景
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测任务超时(超过{TIMEOUT}秒)")
|
||||
has_violation = False
|
||||
data = f"检测超时(超过{TIMEOUT}秒)"
|
||||
detector_type = "timeout"
|
||||
task_id = -1 # 超时任务ID标记为-1
|
||||
# -----------------------------------------------------------------------------------------
|
||||
|
||||
# -------------------------- 【核心修改2:修正日志打印变量名】 --------------------------
|
||||
# 打印检测结果(避免使用Python关键字"type")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
|
||||
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}, 任务ID: {task_id}")
|
||||
# -----------------------------------------------------------------------------------------
|
||||
|
||||
# 处理违规逻辑(变量名从type改为detector_type)
|
||||
if has_violation:
|
||||
print(
|
||||
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - 类型: {type}, 详情: {data}")
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
|
||||
f"类型: {detector_type}, 详情: {data}")
|
||||
|
||||
# 调用违规次数加一方法
|
||||
try:
|
||||
@ -128,8 +151,12 @@ class ClientConnection:
|
||||
danger_msg = {
|
||||
"type": "danger",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": self.client_ip
|
||||
"client_ip": self.client_ip,
|
||||
"detector_type": detector_type,
|
||||
"detail": str(data)
|
||||
}
|
||||
|
||||
# TODO 数据存储到数据库
|
||||
await self.websocket.send_json(danger_msg)
|
||||
else:
|
||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
|
||||
@ -232,7 +259,7 @@ ws_router = APIRouter()
|
||||
|
||||
@ws_router.websocket(WS_ENDPOINT)
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
# 加载模型
|
||||
# 加载模型(首次连接时自动加载,线程安全)
|
||||
load_model()
|
||||
await websocket.accept()
|
||||
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
||||
@ -242,7 +269,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
is_online_updated = False
|
||||
|
||||
try:
|
||||
# 处理重复连接
|
||||
# 处理重复连接(关闭同一IP的旧连接)
|
||||
if client_ip in connected_clients:
|
||||
old_conn = connected_clients[client_ip]
|
||||
if old_conn.consumer_task and not old_conn.consumer_task.done():
|
||||
@ -255,7 +282,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
new_conn = ClientConnection(websocket, client_ip)
|
||||
connected_clients[client_ip] = new_conn
|
||||
new_conn.start_consumer()
|
||||
# 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧(后续靠取帧后自动发)
|
||||
# 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧
|
||||
await new_conn.send_frame_permit()
|
||||
|
||||
# 标记上线并记录
|
||||
@ -270,7 +297,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
|
||||
|
||||
# 消息循环
|
||||
# 消息循环(接收客户端文本/二进制消息)
|
||||
while True:
|
||||
data = await websocket.receive()
|
||||
if "text" in data:
|
||||
@ -289,7 +316,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
|
||||
# 主动/异常断开时标记离线
|
||||
# 主动/异常断开时标记离线(仅当上线状态更新成功时)
|
||||
if is_online_updated:
|
||||
try:
|
||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
|
||||
@ -300,4 +327,4 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
|
||||
|
||||
connected_clients.pop(client_ip, None)
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}")
|
||||
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}")
|
Reference in New Issue
Block a user