修改WS兼容检测的Future对象

This commit is contained in:
2025-09-08 18:10:49 +08:00
parent 8ceb92c572
commit 1dd832e18d
2 changed files with 52 additions and 26 deletions

View File

@ -9,7 +9,7 @@ import numpy as np
# -------------------------- 核心配置参数 --------------------------
MAX_WORKERS = 6 # 线程池最大线程数
DETECTION_ORDER = ["yolo", "face", "ocr"] # 检测优先级顺序
TIMEOUT = 30 # 检测超时时间(秒)
TIMEOUT = 30 # 检测超时时间(秒) 【确保此常量可被外部导入】
# -------------------------- 全局状态管理 --------------------------
_executor = None # 线程池实例
@ -80,7 +80,7 @@ def shutdown():
# -------------------------- 检测逻辑实现 --------------------------
def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
"""在子线程中执行检测逻辑"""
"""在子线程中执行检测逻辑返回4元素tuple是否成功、结果、检测器类型、任务ID"""
thread_name = threading.current_thread().name
print(f"任务[{task_id}] 开始执行、线程: {thread_name}")
@ -99,27 +99,27 @@ def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}")
if flag:
print(f"任务[{task_id}] 完成检测、使用检测器: {detector}")
return (True, result, detector, task_id)
return (True, result, detector, task_id) # 4元素tuple
# 所有检测器均未检测到结果
print(f"任务[{task_id}] 所有检测器均未检测到内容")
return (False, "未检测到任何内容", "none", task_id)
return (False, "未检测到任何内容", "none", task_id) # 4元素tuple
except Exception as e:
print(f"任务[{task_id}] 检测过程发生错误: {str(e)}")
return (False, f"检测错误: {str(e)}", "error", task_id)
return (False, f"检测错误: {str(e)}", "error", task_id) # 4元素tuple
# -------------------------- 外部调用接口 --------------------------
def detect(frame: np.ndarray) -> Future:
"""
提交检测任务到线程池
提交检测任务到线程池返回Future对象需调用result()获取4元素结果
参数:
frame: 待检测图像(ndarray格式、cv2.imdecode生成)
返回:
Future对象、通过result()方法获取检测结果
Future对象、result()返回tuple: (has_violation, data, detector_type, task_id)
"""
# 确保模型已加载
if not _model_loaded:
@ -129,8 +129,7 @@ def detect(frame: np.ndarray) -> Future:
# 生成任务ID
task_id = _get_next_task_id()
# 提交任务到线程池
# 提交任务到线程池返回Future
future = _executor.submit(_detect_in_thread, frame, task_id)
print(f"任务[{task_id}]: 已提交到线程池")
return future

View File

@ -7,12 +7,12 @@ from typing import Dict, Optional, AsyncGenerator
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
from core.all import detect
# 【修改1导入detect和TIMEOUT用于检测超时控制
from core.all import detect, load_model, TIMEOUT
import cv2
import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from core.all import load_model
# 配置常量
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
@ -93,7 +93,7 @@ class ClientConnection:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法"""
"""处理单帧图像数据(【核心修改:等待检测结果+修正解包】"""
# 二进制数据转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
@ -109,13 +109,36 @@ class ClientConnection:
try:
cv2.imwrite(filename, img)
print(f"[{get_current_time_str()}] 图像已保存至: {filename}")
has_violation, data, type = detect(img)
print(has_violation)
print(type)
print(data)
# -------------------------- 【核心修改1提交检测任务并等待结果】 --------------------------
# 1. 提交检测任务获取Future对象非阻塞
detection_future = detect(img)
# 2. 用asyncio.to_thread等待Future结果避免阻塞asyncio事件循环设置超时
try:
# 解包4元素结果(是否违规, 结果数据, 检测器类型, 任务ID)
has_violation, data, detector_type, task_id = await asyncio.to_thread(
detection_future.result, # 调用Future的result()获取实际结果
timeout=TIMEOUT # 超时控制与all.py配置一致
)
except TimeoutError:
# 处理检测超时场景
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测任务超时(超过{TIMEOUT}秒)")
has_violation = False
data = f"检测超时(超过{TIMEOUT}秒)"
detector_type = "timeout"
task_id = -1 # 超时任务ID标记为-1
# -----------------------------------------------------------------------------------------
# -------------------------- 【核心修改2修正日志打印变量名】 --------------------------
# 打印检测结果避免使用Python关键字"type"
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}, 任务ID: {task_id}")
# -----------------------------------------------------------------------------------------
# 处理违规逻辑变量名从type改为detector_type
if has_violation:
print(
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - 类型: {type}, 详情: {data}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
f"类型: {detector_type}, 详情: {data}")
# 调用违规次数加一方法
try:
@ -128,8 +151,12 @@ class ClientConnection:
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
"client_ip": self.client_ip,
"detector_type": detector_type,
"detail": str(data)
}
# TODO 数据存储到数据库
await self.websocket.send_json(danger_msg)
else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
@ -232,7 +259,7 @@ ws_router = APIRouter()
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
# 加载模型
# 加载模型(首次连接时自动加载,线程安全)
load_model()
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
@ -242,7 +269,7 @@ async def websocket_endpoint(websocket: WebSocket):
is_online_updated = False
try:
# 处理重复连接
# 处理重复连接关闭同一IP的旧连接
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
@ -255,7 +282,7 @@ async def websocket_endpoint(websocket: WebSocket):
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer()
# 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧(后续靠取帧后自动发)
# 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧
await new_conn.send_frame_permit()
# 标记上线并记录
@ -270,7 +297,7 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
# 消息循环
# 消息循环(接收客户端文本/二进制消息)
while True:
data = await websocket.receive()
if "text" in data:
@ -289,7 +316,7 @@ async def websocket_endpoint(websocket: WebSocket):
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 主动/异常断开时标记离线
# 主动/异常断开时标记离线(仅当上线状态更新成功时)
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)