修改WS兼容检测的Future对象

This commit is contained in:
2025-09-08 18:10:49 +08:00
parent 8ceb92c572
commit 1dd832e18d
2 changed files with 52 additions and 26 deletions

View File

@ -9,7 +9,7 @@ import numpy as np
# -------------------------- 核心配置参数 -------------------------- # -------------------------- 核心配置参数 --------------------------
MAX_WORKERS = 6 # 线程池最大线程数 MAX_WORKERS = 6 # 线程池最大线程数
DETECTION_ORDER = ["yolo", "face", "ocr"] # 检测优先级顺序 DETECTION_ORDER = ["yolo", "face", "ocr"] # 检测优先级顺序
TIMEOUT = 30 # 检测超时时间(秒) TIMEOUT = 30 # 检测超时时间(秒) 【确保此常量可被外部导入】
# -------------------------- 全局状态管理 -------------------------- # -------------------------- 全局状态管理 --------------------------
_executor = None # 线程池实例 _executor = None # 线程池实例
@ -80,7 +80,7 @@ def shutdown():
# -------------------------- 检测逻辑实现 -------------------------- # -------------------------- 检测逻辑实现 --------------------------
def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple: def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
"""在子线程中执行检测逻辑""" """在子线程中执行检测逻辑返回4元素tuple是否成功、结果、检测器类型、任务ID"""
thread_name = threading.current_thread().name thread_name = threading.current_thread().name
print(f"任务[{task_id}] 开始执行、线程: {thread_name}") print(f"任务[{task_id}] 开始执行、线程: {thread_name}")
@ -99,27 +99,27 @@ def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}") print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}")
if flag: if flag:
print(f"任务[{task_id}] 完成检测、使用检测器: {detector}") print(f"任务[{task_id}] 完成检测、使用检测器: {detector}")
return (True, result, detector, task_id) return (True, result, detector, task_id) # 4元素tuple
# 所有检测器均未检测到结果 # 所有检测器均未检测到结果
print(f"任务[{task_id}] 所有检测器均未检测到内容") print(f"任务[{task_id}] 所有检测器均未检测到内容")
return (False, "未检测到任何内容", "none", task_id) return (False, "未检测到任何内容", "none", task_id) # 4元素tuple
except Exception as e: except Exception as e:
print(f"任务[{task_id}] 检测过程发生错误: {str(e)}") print(f"任务[{task_id}] 检测过程发生错误: {str(e)}")
return (False, f"检测错误: {str(e)}", "error", task_id) return (False, f"检测错误: {str(e)}", "error", task_id) # 4元素tuple
# -------------------------- 外部调用接口 -------------------------- # -------------------------- 外部调用接口 --------------------------
def detect(frame: np.ndarray) -> Future: def detect(frame: np.ndarray) -> Future:
""" """
提交检测任务到线程池 提交检测任务到线程池返回Future对象需调用result()获取4元素结果
参数: 参数:
frame: 待检测图像(ndarray格式、cv2.imdecode生成) frame: 待检测图像(ndarray格式、cv2.imdecode生成)
返回: 返回:
Future对象、通过result()方法获取检测结果 Future对象、result()返回tuple: (has_violation, data, detector_type, task_id)
""" """
# 确保模型已加载 # 确保模型已加载
if not _model_loaded: if not _model_loaded:
@ -129,8 +129,7 @@ def detect(frame: np.ndarray) -> Future:
# 生成任务ID # 生成任务ID
task_id = _get_next_task_id() task_id = _get_next_task_id()
# 提交任务到线程池 # 提交任务到线程池返回Future
future = _executor.submit(_detect_in_thread, frame, task_id) future = _executor.submit(_detect_in_thread, frame, task_id)
print(f"任务[{task_id}]: 已提交到线程池") print(f"任务[{task_id}]: 已提交到线程池")
return future return future

View File

@ -7,12 +7,12 @@ from typing import Dict, Optional, AsyncGenerator
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate from schema.device_action_schema import DeviceActionCreate
from core.all import detect # 【修改1导入detect和TIMEOUT用于检测超时控制
from core.all import detect, load_model, TIMEOUT
import cv2 import cv2
import numpy as np import numpy as np
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from core.all import load_model
# 配置常量 # 配置常量
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
@ -93,7 +93,7 @@ class ClientConnection:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None: async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法""" """处理单帧图像数据(【核心修改:等待检测结果+修正解包】"""
# 二进制数据转OpenCV图像 # 二进制数据转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8) nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
@ -109,13 +109,36 @@ class ClientConnection:
try: try:
cv2.imwrite(filename, img) cv2.imwrite(filename, img)
print(f"[{get_current_time_str()}] 图像已保存至: {filename}") print(f"[{get_current_time_str()}] 图像已保存至: {filename}")
has_violation, data, type = detect(img)
print(has_violation) # -------------------------- 【核心修改1提交检测任务并等待结果】 --------------------------
print(type) # 1. 提交检测任务获取Future对象非阻塞
print(data) detection_future = detect(img)
# 2. 用asyncio.to_thread等待Future结果避免阻塞asyncio事件循环设置超时
try:
# 解包4元素结果(是否违规, 结果数据, 检测器类型, 任务ID)
has_violation, data, detector_type, task_id = await asyncio.to_thread(
detection_future.result, # 调用Future的result()获取实际结果
timeout=TIMEOUT # 超时控制与all.py配置一致
)
except TimeoutError:
# 处理检测超时场景
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测任务超时(超过{TIMEOUT}秒)")
has_violation = False
data = f"检测超时(超过{TIMEOUT}秒)"
detector_type = "timeout"
task_id = -1 # 超时任务ID标记为-1
# -----------------------------------------------------------------------------------------
# -------------------------- 【核心修改2修正日志打印变量名】 --------------------------
# 打印检测结果避免使用Python关键字"type"
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}, 任务ID: {task_id}")
# -----------------------------------------------------------------------------------------
# 处理违规逻辑变量名从type改为detector_type
if has_violation: if has_violation:
print( print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - 类型: {type}, 详情: {data}") f"类型: {detector_type}, 详情: {data}")
# 调用违规次数加一方法 # 调用违规次数加一方法
try: try:
@ -128,8 +151,12 @@ class ClientConnection:
danger_msg = { danger_msg = {
"type": "danger", "type": "danger",
"timestamp": get_current_time_str(), "timestamp": get_current_time_str(),
"client_ip": self.client_ip "client_ip": self.client_ip,
"detector_type": detector_type,
"detail": str(data)
} }
# TODO 数据存储到数据库
await self.websocket.send_json(danger_msg) await self.websocket.send_json(danger_msg)
else: else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
@ -232,7 +259,7 @@ ws_router = APIRouter()
@ws_router.websocket(WS_ENDPOINT) @ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
# 加载模型 # 加载模型(首次连接时自动加载,线程安全)
load_model() load_model()
await websocket.accept() await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip" client_ip = websocket.client.host if websocket.client else "unknown_ip"
@ -242,7 +269,7 @@ async def websocket_endpoint(websocket: WebSocket):
is_online_updated = False is_online_updated = False
try: try:
# 处理重复连接 # 处理重复连接关闭同一IP的旧连接
if client_ip in connected_clients: if client_ip in connected_clients:
old_conn = connected_clients[client_ip] old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done(): if old_conn.consumer_task and not old_conn.consumer_task.done():
@ -255,7 +282,7 @@ async def websocket_endpoint(websocket: WebSocket):
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn connected_clients[client_ip] = new_conn
new_conn.start_consumer() new_conn.start_consumer()
# 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧(后续靠取帧后自动发) # 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧
await new_conn.send_frame_permit() await new_conn.send_frame_permit()
# 标记上线并记录 # 标记上线并记录
@ -270,7 +297,7 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}") print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
# 消息循环 # 消息循环(接收客户端文本/二进制消息)
while True: while True:
data = await websocket.receive() data = await websocket.receive()
if "text" in data: if "text" in data:
@ -289,7 +316,7 @@ async def websocket_endpoint(websocket: WebSocket):
if conn.consumer_task and not conn.consumer_task.done(): if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel() conn.consumer_task.cancel()
# 主动/异常断开时标记离线 # 主动/异常断开时标记离线(仅当上线状态更新成功时)
if is_online_updated: if is_online_updated:
try: try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
@ -300,4 +327,4 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None) connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}")