识别结果保存到对应目录下

This commit is contained in:
2025-09-09 16:30:12 +08:00
parent 0fe49bf829
commit 532a9e75e9
6 changed files with 375 additions and 325 deletions

View File

@ -3,12 +3,11 @@ import datetime
import json
import os
from contextlib import asynccontextmanager
from typing import Dict, Optional, AsyncGenerator
from typing import Dict, Optional
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
# 【修改1导入detect和TIMEOUT用于检测超时控制
from core.all import detect, load_model, TIMEOUT
from core.all import detect, load_model
import cv2
import numpy as np
@ -21,7 +20,7 @@ WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# 工具函数: 获取格式化时间字符串(统一时间戳格式)
# 工具函数: 获取格式化时间字符串
def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -40,13 +39,13 @@ class ClientConnection:
self.consumer_task: Optional[asyncio.Task] = None
def update_heartbeat(self):
"""更新心跳时间(客户端发送心跳时调用)"""
"""更新心跳时间"""
self.last_heartbeat = datetime.datetime.now()
def is_alive(self) -> bool:
"""判断客户端是否存活(心跳超时检查)"""
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT
"""判断客户端是否存活"""
timeout_seconds = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout_seconds < HEARTBEAT_TIMEOUT
def start_consumer(self):
"""启动帧消费任务"""
@ -54,10 +53,7 @@ class ClientConnection:
return self.consumer_task
async def send_frame_permit(self):
"""
发送「帧发送许可信号」
通知客户端可发送下一帧图像
"""
"""发送帧发送许可信号"""
try:
frame_permit_msg = {
"type": "frame",
@ -65,26 +61,21 @@ class ClientConnection:
"client_ip": self.client_ip
}
await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号(取帧后立即通知)")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}")
async def consume_frames(self) -> None:
"""消费队列中的帧并处理(核心调整: 取帧后立即发许可、再处理帧)"""
"""消费队列中的帧并处理"""
try:
while True:
# 1. 从队列取出帧(阻塞直到有帧可用)
# 取出帧并立即发送下一帧许可
frame_data = await self.frame_queue.get()
# -------------------------- 核心修改: 取出帧后立即发送下一帧许可 --------------------------
await self.send_frame_permit() # 取帧即通知客户端发下一帧、无需等处理完成
# -----------------------------------------------------------------------------------------
await self.send_frame_permit()
try:
# 2. 处理取出的帧(即使处理慢、客户端也已收到许可、可提前准备下一帧)
await self.process_frame(frame_data)
finally:
# 3. 标记帧任务完成(无论处理成功/失败、都需清理队列)
self.frame_queue.task_done()
except asyncio.CancelledError:
@ -93,8 +84,8 @@ class ClientConnection:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据"""
# 二进制数据转OpenCV图像
"""处理单帧图像数据核心修复按3个返回值解包"""
# 二进制转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
@ -102,52 +93,41 @@ class ClientConnection:
return
try:
# -------------------------- 提交检测任务并等待结果 --------------------------
# 1. 提交检测任务获取Future对象非阻塞
detection_future = detect(img)
# 2. 用asyncio.to_thread等待Future结果避免阻塞asyncio事件循环设置超时
try:
# 解包4元素结果(是否违规, 结果数据, 检测器类型, 任务ID)
has_violation, data, detector_type, task_id = await asyncio.to_thread(
detection_future.result, # 调用Future的result()获取实际结果
timeout=TIMEOUT # 超时控制与all.py配置一致
)
except TimeoutError:
# 处理检测超时场景
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测任务超时(超过{TIMEOUT}秒)")
has_violation = False
data = f"检测超时(超过{TIMEOUT}秒)"
detector_type = "timeout"
task_id = -1 # 超时任务ID标记为-1
# -----------------------------------------------------------------------------------------
# -------------------------- 修复核心匹配detect返回的3个值 --------------------------
# 假设detect返回 (是否违规, 结果数据, 检测器类型)
has_violation, data, detector_type = await asyncio.to_thread(
detect, # 调用检测函数
img # 传入图像参数
)
# -------------------------------------------------------------------------------------
# 打印检测结果
# 打印检测结果移除task_id相关内容
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测结果 - "
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}, 任务ID: {task_id}")
f"违规: {has_violation}, 类型: {detector_type}, 数据: {data}")
# 处理违规逻辑
if has_violation:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - "
f"类型: {detector_type}, 详情: {data}")
# 调用违规次数加一方法
# 违规次数+1
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 发送危险通知
# 发送危险通知
danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
"client_ip": self.client_ip,
"detail": data
}
# TODO 数据存储到数据库
await self.websocket.send_json(danger_msg)
else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
@ -157,7 +137,7 @@ connected_clients: Dict[str, ClientConnection] = {}
heartbeat_task: Optional[asyncio.Task] = None
# 心跳检查(定时清理超时客户端 + 调用离线状态更新方法)
# 心跳检查任务
async def heartbeat_checker():
while True:
current_time = get_current_time_str()
@ -172,7 +152,7 @@ async def heartbeat_checker():
conn.consumer_task.cancel()
await conn.websocket.close(code=1008, reason="心跳超时")
# 超时设为离线并记录
# 标记离线
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
@ -247,7 +227,6 @@ ws_router = APIRouter()
@ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket):
# 加载模型(首次连接时自动加载,线程安全)
load_model()
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
@ -257,7 +236,7 @@ async def websocket_endpoint(websocket: WebSocket):
is_online_updated = False
try:
# 处理重复连接关闭同一IP的旧连接
# 处理重复连接
if client_ip in connected_clients:
old_conn = connected_clients[client_ip]
if old_conn.consumer_task and not old_conn.consumer_task.done():
@ -270,10 +249,9 @@ async def websocket_endpoint(websocket: WebSocket):
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer()
# 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧
await new_conn.send_frame_permit()
# 标记上线并记录
# 标记上线
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
@ -285,7 +263,7 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
# 消息循环(接收客户端文本/二进制消息)
# 消息循环
while True:
data = await websocket.receive()
if "text" in data:
@ -298,13 +276,12 @@ async def websocket_endpoint(websocket: WebSocket):
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally:
# 清理资源并标记离线
# 清理资源
if client_ip in connected_clients:
conn = connected_clients[client_ip]
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
# 主动/异常断开时标记离线(仅当上线状态更新成功时)
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)