This commit is contained in:
ZZX9599
2025-09-04 12:29:27 +08:00
parent ea82a33a8f
commit b5d870a19c
5 changed files with 461 additions and 339 deletions

View File

@ -0,0 +1,36 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
# ------------------------------
# 请求模型(新增记录用,极简)
# ------------------------------
class DeviceActionCreate(BaseModel):
"""设备操作记录创建模型0=离线1=上线)"""
client_ip: str = Field(..., description="客户端IP")
action: int = Field(..., ge=0, le=1, description="操作状态0=离线1=上线)")
# ------------------------------
# 响应模型(单条记录)
# ------------------------------
class DeviceActionResponse(BaseModel):
"""设备操作记录响应模型(与自增表对齐)"""
id: int = Field(..., description="自增主键ID")
client_ip: Optional[str] = Field(None, description="客户端IP")
action: Optional[int] = Field(None, description="操作状态0=离线1=上线)")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库结果直接转换
model_config = {"from_attributes": True}
# ------------------------------
# 列表响应模型(仅含 total + device_actions
# ------------------------------
class DeviceActionListResponse(BaseModel):
"""设备操作记录列表(仅核心返回字段)"""
total: int = Field(..., description="总记录数")
device_actions: List[DeviceActionResponse] = Field(..., description="操作记录列表")

View File

@ -1,4 +1,3 @@
import hashlib
from datetime import datetime from datetime import datetime
from typing import Optional, List, Dict from typing import Optional, List, Dict
@ -6,42 +5,31 @@ from pydantic import BaseModel, Field
# ------------------------------ # ------------------------------
# 请求模型(前端传参校验) # 请求模型
# ------------------------------ # ------------------------------
class DeviceCreateRequest(BaseModel): class DeviceCreateRequest(BaseModel):
"""设备流信息创建请求模型""" """设备流信息创建请求模型(与数据库表字段对齐)"""
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名") hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
params: Optional[Dict] = Field(None, description="设备详细信息") params: Optional[Dict] = Field(None, description="设备详细信息JSON格式")
def md5_encrypt(text: str) -> str:
"""对字符串进行MD5加密"""
if not text:
return ""
md5_hash = hashlib.md5()
md5_hash.update(text.encode('utf-8'))
return md5_hash.hexdigest()
# ------------------------------ # ------------------------------
# 响应模型(后端返回设备数据) # 响应模型(后端返回数据)- 严格对齐数据库表字段
# ------------------------------ # ------------------------------
class DeviceResponse(BaseModel): class DeviceResponse(BaseModel):
"""设备流信息响应模型(字段与表结构完全对齐""" """设备流信息响应模型(与数据库表字段完全一致"""
id: int = Field(..., description="设备ID") id: int = Field(..., description="设备主键ID")
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名") hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
rtmp_push_url: Optional[str] = Field(None, description="需要推送的RTMP地址")
live_webrtc_url: Optional[str] = Field(None, description="直播的Webrtc地址")
detection_webrtc_url: Optional[str] = Field(None, description="检测的Webrtc地址")
device_online_status: int = Field(..., description="设备在线状态1-在线、0-离线)") device_online_status: int = Field(..., description="设备在线状态1-在线、0-离线)")
device_type: Optional[str] = Field(None, description="设备类型") device_type: Optional[str] = Field(None, description="设备类型")
alarm_count: int = Field(..., description="报警次数") alarm_count: int = Field(..., description="报警次数")
params: Optional[str] = Field(None, description="设备详细信息") params: Optional[str] = Field(None, description="设备详细信息JSON字符串")
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库查询结果转换 # 支持从数据库查询结果直接转换
model_config = {"from_attributes": True} model_config = {"from_attributes": True}

View File

@ -0,0 +1,119 @@
from fastapi import APIRouter, Query
from mysql.connector import Error as MySQLError
from ds.db import db
from schema.device_action_schema import (
DeviceActionCreate,
DeviceActionResponse,
DeviceActionListResponse
)
from schema.response_schema import APIResponse
# 路由配置
router = APIRouter(
prefix="/device/actions",
tags=["设备操作记录"]
)
# ------------------------------
# 内部方法新增设备操作记录适配id自增
# ------------------------------
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
"""
新增设备操作记录(内部方法,非接口)
:param action_data: 含client_ip和action0/1
:return: 新增的完整记录
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入SQLid自增依赖数据库自动生成
insert_query = """
INSERT INTO device_action
(client_ip, action, created_at, updated_at)
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (
action_data.client_ip,
action_data.action
))
conn.commit()
# 获取新增记录通过自增ID查询
new_id = cursor.lastrowid
cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,))
new_action = cursor.fetchone()
return DeviceActionResponse(**new_action)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"新增记录失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口:分页查询操作记录列表(仅返回 total + device_actions
# ------------------------------
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
async def get_device_action_list(
page: int = Query(1, ge=1, description="页码默认1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100"),
client_ip: str = Query(None, description="按客户端IP筛选"),
action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线1=上线)")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 构建筛选条件(参数化查询,避免注入)
where_clause = []
params = []
if client_ip:
where_clause.append("client_ip = %s")
params.append(client_ip)
if action is not None:
where_clause.append("action = %s")
params.append(action)
# 2. 查询总记录数(用于返回 total
count_sql = "SELECT COUNT(*) AS total FROM device_action"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 3. 分页查询记录(按创建时间倒序,确保最新记录在前)
offset = (page - 1) * page_size
list_sql = "SELECT * FROM device_action"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) # 追加分页参数page/page_size仅用于查询不返回
cursor.execute(list_sql, params)
action_list = cursor.fetchall()
# 4. 仅返回 total + device_actions
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -1,25 +1,11 @@
import json import json
import threading
import time
from fastapi import HTTPException, Query, APIRouter, Depends, Request from fastapi import APIRouter, Query, HTTPException
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
from ds.config import LIVE_CONFIG
from ds.db import db from ds.db import db
from middle.auth_middleware import get_current_user from schema.device_schema import DeviceCreateRequest, DeviceResponse, DeviceListResponse
# 注意导入的Schema已更新字段
from schema.device_schema import (
DeviceCreateRequest,
DeviceResponse,
DeviceListResponse,
md5_encrypt
)
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
from schema.user_schema import UserResponse
# 导入之前封装的WEBRTC处理函数
from core.rtmp import rtmp_pull_video_stream
router = APIRouter( router = APIRouter(
prefix="/devices", prefix="/devices",
@ -27,65 +13,128 @@ router = APIRouter(
) )
# 在后台线程中运行WEBRTC处理
def run_webrtc_processing(ip, webrtc_url):
try:
print(f"开始处理来自设备 {ip} 的WEBRTC流: {webrtc_url}")
rtmp_pull_video_stream(webrtc_url)
except Exception as e:
print(f"WEBRTC处理出错: {str(e)}")
# ------------------------------ # ------------------------------
# 1. 创建设备信息 # 内部工具方法 - 通过客户端IP增加设备报警次数
# ------------------------------ # ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息") def increment_alarm_count_by_ip(client_ip: str) -> bool:
async def create_device(request: Request, device_data: DeviceCreateRequest): """
通过客户端IP增加设备的报警次数内部服务方法
:param client_ip: 客户端IP地址
:return: 操作是否成功
"""
if not client_ip:
raise ValueError("客户端IP不能为空")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
if not cursor.fetchone():
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 报警次数加1并更新时间戳
update_query = """
UPDATE devices
SET alarm_count = alarm_count + 1,
updated_at = CURRENT_TIMESTAMP
WHERE client_ip = %s
"""
cursor.execute(update_query, (client_ip,))
conn.commit()
return True
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新报警次数失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法 - 通过客户端IP更新设备在线状态
# ------------------------------
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
"""
通过客户端IP更新设备的在线状态内部服务方法
:param client_ip: 客户端IP地址
:param online_status: 在线状态1-在线、0-离线)
:return: 操作是否成功
"""
if not client_ip:
raise ValueError("客户端IP不能为空")
# 验证状态值有效性
if online_status not in (0, 1):
raise ValueError("在线状态必须是0离线或1在线")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
if not cursor.fetchone():
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 更新在线状态和时间戳
update_query = """
UPDATE devices
SET device_online_status = %s,
updated_at = CURRENT_TIMESTAMP
WHERE client_ip = %s
"""
cursor.execute(update_query, (online_status, client_ip))
conn.commit()
return True
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新设备在线状态失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 原有接口保持不变
# ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
async def create_device(device_data: DeviceCreateRequest):
# 原有代码保持不变
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 检查client_ip是否已存在
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,)) cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
existing_device = cursor.fetchone() existing_device = cursor.fetchone()
if existing_device: if existing_device:
# 设备创建成功后在后台线程启动WEBRTC流处理 # 更新设备状态为在线
threading.Thread( update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
target=run_webrtc_processing, # 返回信息
# args=(device_data.ip, existing_device["live_webrtc_url"]),
args=(device_data.ip, existing_device["rtmp_push_url"]),
daemon=True # 设为守护线程,主程序退出时自动结束
).start()
# IP已存在时返回该设备信息
return APIResponse( return APIResponse(
code=200, code=200,
message=f"客户端IP {device_data.ip} 已存在", message=f"设备IP {device_data.ip} 已存在,返回已有设备信息",
data=DeviceResponse(**existing_device) data=DeviceResponse(**existing_device)
) )
# 获取RTMP URL和WEBRTC URL配置 from fastapi import Request
rtmp_url = str(LIVE_CONFIG.get("rtmp_url", "")) request = Request(scope={"type": "http"})
webrtc_url = str(LIVE_CONFIG.get("webrtc_url", ""))
# 将设备详细信息params转换为JSON字符串
device_params_json = json.dumps(device_data.params) if device_data.params else None
# 对JSON字符串进行MD5加密
device_md5 = md5_encrypt(device_params_json) if device_params_json else ""
# 解析User-Agent获取设备类型
user_agent = request.headers.get("User-Agent", "").lower() user_agent = request.headers.get("User-Agent", "").lower()
# 优先处理User-Agent为default的情况
if user_agent == "default": if user_agent == "default":
# 检查params中是否存在os键 device_type = device_data.params.get("os") if (
if device_data.params and isinstance(device_data.params, dict) and "os" in device_data.params: device_data.params and isinstance(device_data.params, dict)) else "unknown"
device_type = device_data.params["os"]
else:
device_type = "unknown"
elif "windows" in user_agent: elif "windows" in user_agent:
device_type = "windows" device_type = "windows"
elif "android" in user_agent: elif "android" in user_agent:
@ -95,22 +144,16 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
else: else:
device_type = "unknown" device_type = "unknown"
# 构建完整的WEBRTC URL device_params_json = json.dumps(device_data.params) if device_data.params else None
full_webrtc_url = webrtc_url + device_md5
# SQL插入语句
insert_query = """ insert_query = """
INSERT INTO devices INSERT INTO devices
(client_ip, hostname, rtmp_push_url, live_webrtc_url, detection_webrtc_url, (client_ip, hostname, device_online_status, device_type, alarm_count, params)
device_online_status, device_type, alarm_count, params) VALUES (%s, %s, %s, %s, %s, %s)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
""" """
cursor.execute(insert_query, ( cursor.execute(insert_query, (
device_data.ip, device_data.ip,
device_data.hostname, device_data.hostname,
rtmp_url + device_md5,
full_webrtc_url, # 存储完整的WEBRTC URL
"",
1, 1,
device_type, device_type,
0, 0,
@ -118,28 +161,22 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
)) ))
conn.commit() conn.commit()
# 获取刚创建的设备信息
device_id = cursor.lastrowid device_id = cursor.lastrowid
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
device = cursor.fetchone() new_device = cursor.fetchone()
# 设备创建成功后在后台线程启动WEBRTC流处理
threading.Thread(
target=run_webrtc_processing,
args=(device_data.ip, full_webrtc_url),
daemon=True # 设为守护线程,主程序退出时自动结束
).start()
return APIResponse( return APIResponse(
code=200, code=200,
message="设备创建成功已开始处理WEBRTC流", message="设备创建成功",
data=DeviceResponse(**device) data=DeviceResponse(**new_device)
) )
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"创建设备失败:{str(e)}") from e raise Exception(f"创建设备失败:{str(e)}") from e
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise Exception(f"设备信息JSON序列化失败{str(e)}") from e raise Exception(f"设备详细信息JSON序列化失败{str(e)}") from e
except Exception as e: except Exception as e:
if conn: if conn:
conn.rollback() conn.rollback()
@ -148,139 +185,56 @@ async def create_device(request: Request, device_data: DeviceCreateRequest):
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
# 2. 获取设备列表
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备列表")
async def get_device_list( async def get_device_list(
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数"), page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
device_type: str = Query(None, description="设备类型筛选"), device_type: str = Query(None, description="设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="在线状态筛选1-在线、0-离线)") online_status: int = Query(None, ge=0, le=1, description="在线状态筛选")
): ):
# 原有代码保持不变
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 构建查询条件
where_clause = [] where_clause = []
params = [] params = []
if device_type: if device_type:
where_clause.append("device_type = %s") where_clause.append("device_type = %s")
params.append(device_type) params.append(device_type)
if online_status is not None: if online_status is not None:
where_clause.append("device_online_status = %s") where_clause.append("device_online_status = %s")
params.append(online_status) params.append(online_status)
# 总条数查询 count_query = "SELECT COUNT(*) AS total FROM devices"
count_query = "SELECT COUNT(*) as total FROM devices"
if where_clause: if where_clause:
count_query += " WHERE " + " AND ".join(where_clause) count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params) cursor.execute(count_query, params)
total = cursor.fetchone()["total"] total = cursor.fetchone()["total"]
# 分页查询SELECT * 会自动匹配表字段、响应模型已对齐)
offset = (page - 1) * page_size offset = (page - 1) * page_size
query = "SELECT * FROM devices" list_query = "SELECT * FROM devices"
if where_clause: if where_clause:
query += " WHERE " + " AND ".join(where_clause) list_query += " WHERE " + " AND ".join(where_clause)
query += " ORDER BY id DESC LIMIT %s OFFSET %s" list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) params.extend([page_size, offset])
cursor.execute(query, params) cursor.execute(list_query, params)
devices = cursor.fetchall() device_list = cursor.fetchall()
# 响应模型已更新为params字段、直接转换即可
device_list = [DeviceResponse(**device) for device in devices]
return APIResponse( return APIResponse(
code=200, code=200,
message="获取设备列表成功", message="获取设备列表成功",
data=DeviceListResponse(total=total, devices=device_list) data=DeviceListResponse(
total=total,
devices=[DeviceResponse(**device) for device in device_list]
)
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"获取设备列表失败:{str(e)}") from e raise Exception(f"获取设备列表失败:{str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------
# 3. 获取单个设备详情
# ------------------------------
@router.get("/{device_id}", response_model=APIResponse, summary="获取设备详情")
async def get_device_detail(
device_id: int,
current_user: UserResponse = Depends(get_current_user)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 查询设备信息SELECT * 匹配表字段)
query = "SELECT * FROM devices WHERE id = %s"
cursor.execute(query, (device_id,))
device = cursor.fetchone()
if not device:
raise HTTPException(
status_code=404,
detail=f"设备ID为 {device_id} 的设备不存在"
)
# 响应模型已更新为params字段
return APIResponse(
code=200,
message="获取设备详情成功",
data=DeviceResponse(**device)
)
except MySQLError as e:
raise Exception(f"获取设备详情失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 4. 删除设备信息
# ------------------------------
@router.delete("/{device_id}", response_model=APIResponse, summary="删除设备信息")
async def delete_device(
device_id: int,
current_user: UserResponse = Depends(get_current_user)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE id = %s", (device_id,))
if not cursor.fetchone():
raise HTTPException(
status_code=404,
detail=f"设备ID为 {device_id} 的设备不存在"
)
# 执行删除
delete_query = "DELETE FROM devices WHERE id = %s"
cursor.execute(delete_query, (device_id,))
conn.commit()
return APIResponse(
code=200,
message=f"设备ID为 {device_id} 的设备已成功删除",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"删除设备失败:{str(e)}") from e
finally:
db.close_connection(conn, cursor)

309
ws/ws.py
View File

@ -4,6 +4,9 @@ import json
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Optional, AsyncGenerator from typing import Dict, Optional, AsyncGenerator
from service.device_service import update_online_status_by_ip, increment_alarm_count_by_ip
from service.device_action_service import add_device_action
from schema.device_action_schema import DeviceActionCreate
import cv2 import cv2
import numpy as np import numpy as np
@ -11,7 +14,7 @@ from fastapi import WebSocket, APIRouter, WebSocketDisconnect, FastAPI
from ocr.model_violation_detector import MultiModelViolationDetector from ocr.model_violation_detector import MultiModelViolationDetector
# 配置文件相对路径(根据实际目录结构调整 # 配置文件路径(建议实际部署时改为相对路径或环境变量
YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt" YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt" FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml"
@ -26,265 +29,287 @@ detector = MultiModelViolationDetector(
ocr_confidence_threshold=0.5 ocr_confidence_threshold=0.5
) )
# -------------------------- 配置常量 -------------------------- # 配置常量
HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒) HEARTBEAT_INTERVAL = 30 # 心跳检查间隔(秒)
HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒) HEARTBEAT_TIMEOUT = 600 # 客户端超时阈值(秒)
WS_ENDPOINT = "/ws" # WebSocket端点路径 WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制保持1确保单帧处理 FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# -------------------------- 核心数据结构与全局变量 --------------------------
ws_router = APIRouter()
# 客户端连接封装(包含帧队列 # 工具函数:获取格式化时间字符串(统一时间戳格式
def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def get_current_time_file_str() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
# 客户端连接封装
class ClientConnection: class ClientConnection:
def __init__(self, websocket: WebSocket, client_ip: str): def __init__(self, websocket: WebSocket, client_ip: str):
self.websocket = websocket self.websocket = websocket
self.client_ip = client_ip self.client_ip = client_ip
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE) # 帧队列长度为1 self.frame_queue = asyncio.Queue(maxsize=FRAME_QUEUE_SIZE)
self.consumer_task: Optional[asyncio.Task] = None # 消费者任务 self.consumer_task: Optional[asyncio.Task] = None
# 更新心跳时间
def update_heartbeat(self): def update_heartbeat(self):
"""更新心跳时间(客户端发送心跳时调用)"""
self.last_heartbeat = datetime.datetime.now() self.last_heartbeat = datetime.datetime.now()
# 检查是否存活超时返回False
def is_alive(self) -> bool: def is_alive(self) -> bool:
"""判断客户端是否存活(心跳超时检查)"""
timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds() timeout = (datetime.datetime.now() - self.last_heartbeat).total_seconds()
return timeout < HEARTBEAT_TIMEOUT return timeout < HEARTBEAT_TIMEOUT
# 启动帧消费任务
def start_consumer(self): def start_consumer(self):
"""启动帧消费任务"""
self.consumer_task = asyncio.create_task(self.consume_frames()) self.consumer_task = asyncio.create_task(self.consume_frames())
return self.consumer_task return self.consumer_task
# ---------- 新增:发送“允许发送二进制帧”的信号给客户端 ---------- async def send_frame_permit(self):
async def send_allow_send_frame(self): """
"""向客户端发送JSON信号通知其可发送下一帧二进制数据""" 发送「帧发送许可信号」
通知客户端可发送下一帧图像
"""
try: try:
allow_msg = { frame_permit_msg = {
"type": "allow_send_frame", # 信号类型,与客户端约定 "type": "frame",
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "timestamp": get_current_time_str(),
"status": "ready", # 表示服务器已准备好接收下一帧 "client_ip": self.client_ip
"client_ip": self.client_ip # 可选:便于客户端确认自身身份
} }
await self.websocket.send_json(allow_msg) await self.websocket.send_json(frame_permit_msg)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:已发送「允许发送帧」信号") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)")
except Exception as e: except Exception as e:
# 发送失败大概率是客户端已断开,不影响主流程,仅日志记录 print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}")
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:发送「允许发送帧」信号失败 - {str(e)}")
# 帧消费协程
async def consume_frames(self) -> None: async def consume_frames(self) -> None:
"""队列中获取帧并进行处理,处理完后通知客户端可发送下一帧""" """消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理帧)"""
try: try:
while True: while True:
# 从队列获取帧数据(队列空时会阻塞,等待客户端发送 # 1. 从队列取出帧(阻塞直到有帧可用
frame_data = await self.frame_queue.get() frame_data = await self.frame_queue.get()
# -------------------------- 核心修改:取出帧后立即发送下一帧许可 --------------------------
await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成
# -----------------------------------------------------------------------------------------
try: try:
# 处理帧数据 # 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧)
await self.process_frame(frame_data) await self.process_frame(frame_data)
finally: finally:
# 标记任务完成(队列计数-1此时队列回到空状态 # 3. 标记任务完成(无论处理成功/失败,都需清理队列
self.frame_queue.task_done() self.frame_queue.task_done()
# ---------- 修改:处理完当前帧后,立即通知客户端可发送下一帧 ----------
await self.send_allow_send_frame()
except asyncio.CancelledError: except asyncio.CancelledError:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧消费任务已取消") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
except Exception as e: except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:帧处理错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None: async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(原有逻辑不变""" """处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法"""
# 二进制数据转换为NumPy数组uint8类型 # 二进制数据转OpenCV图像
nparr = np.frombuffer(frame_data, np.uint8) nparr = np.frombuffer(frame_data, np.uint8)
# 解码为图像返回与cv2.imread相同的格式BGR通道的ndarray
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据")
return
# 确保images文件夹存在 # 确保图像保存目录存在
if not os.path.exists('images'): os.makedirs('images', exist_ok=True)
os.makedirs('images')
# 生成唯一的文件名包含时间戳和客户端IP避免文件名冲突
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"images/{self.client_ip.replace('.', '_')}_{timestamp}.jpg"
# 保存图像按IP+时间戳命名,避免冲突)
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
try: try:
# 保存图像到本地
cv2.imwrite(filename, img) cv2.imwrite(filename, img)
print(f"[{datetime.datetime.now():%H:%M:%S}] 图像已保存至:{filename}") print(f"[{get_current_time_str()}] 图像已保存至:{filename}")
# 进行检测 # 执行违规检测
if img is not None: has_violation, violation_type, details = detector.detect_violations(img)
has_violation, violation_type, details = detector.detect_violations(img) if has_violation:
if has_violation: print(
print(f"检测到违规 - 类型: {violation_type}, 详情: {details}") f"[{get_current_time_str()}] 客户端{self.client_ip}检测到违规 - 类型: {violation_type}, 详情: {details}")
# 发送检测结果回客户端(原有逻辑不变)
await self.websocket.send_json({ # 调用违规次数加一方法
"type": "detection_result", try:
"has_violation": has_violation, await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
"violation_type": violation_type, print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1")
"details": details, except Exception as e:
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
})
else: # 发送「危险通知」
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:未检测到任何违规内容") danger_msg = {
"type": "danger",
"timestamp": get_current_time_str(),
"client_ip": self.client_ip
}
await self.websocket.send_json(danger_msg)
else: else:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}无法解析图像数据") print(f"[{get_current_time_str()}] 客户端{self.client_ip}未检测到违规")
except Exception as e: except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{self.client_ip}:图像处理错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像处理错误 - {str(e)}")
# 全局连接管理IP -> 连接实例) # 全局状态管理
connected_clients: Dict[str, ClientConnection] = {} connected_clients: Dict[str, ClientConnection] = {}
# 心跳任务(全局引用,用于关闭时清理)
heartbeat_task: Optional[asyncio.Task] = None heartbeat_task: Optional[asyncio.Task] = None
# -------------------------- 心跳检查逻辑(原有逻辑不变) -------------------------- # 心跳检查(定时清理超时客户端 + 调用离线状态更新方法)
async def heartbeat_checker(): async def heartbeat_checker():
while True: while True:
now = datetime.datetime.now() current_time = get_current_time_str()
# 1. 筛选超时客户端(避免遍历中修改字典)
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
# 2. 处理超时连接(关闭+移除)
if timeout_ips: if timeout_ips:
print(f"[{now:%H:%M:%S}] 心跳检查:{len(timeout_ips)}个客户端超时({timeout_ips}") print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时(IP{timeout_ips}")
for ip in timeout_ips: for ip in timeout_ips:
try: try:
# 取消消费者任务 conn = connected_clients[ip]
if connected_clients[ip].consumer_task and not connected_clients[ip].consumer_task.done(): if conn.consumer_task and not conn.consumer_task.done():
connected_clients[ip].consumer_task.cancel() conn.consumer_task.cancel()
await connected_clients[ip].websocket.close(code=1008, reason="心跳超时") await conn.websocket.close(code=1008, reason="心跳超时")
# 超时设为离线并记录
try:
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
finally: finally:
connected_clients.pop(ip, None) connected_clients.pop(ip, None)
else: else:
print(f"[{now:%H:%M:%S}] 心跳检查:{len(connected_clients)}个客户端在线,无超时") print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
# 3. 等待下一轮检查
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
# -------------------------- 应用生命周期(原有逻辑不变) -------------------------- # 应用生命周期管理
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global heartbeat_task global heartbeat_task
# 启动心跳任务
heartbeat_task = asyncio.create_task(heartbeat_checker()) heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务启动ID{id(heartbeat_task)}") print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID{id(heartbeat_task)}")
yield yield
# 关闭时取消心跳任务
if heartbeat_task and not heartbeat_task.done(): if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel() heartbeat_task.cancel()
try: try:
await heartbeat_task await heartbeat_task
print(f"[{datetime.datetime.now():%H:%M:%S}] 心跳任务已取消") print(f"[{get_current_time_str()}] 全局心跳检查任务已取消")
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# -------------------------- 消息处理(文本/心跳逻辑不变,二进制逻辑保留) -------------------------- # 消息处理工具函数
async def send_heartbeat_ack(client_ip: str): async def send_heartbeat_ack(conn: ClientConnection):
"""回复心跳确认(原有逻辑不变)"""
if client_ip not in connected_clients:
return False
try: try:
ack = { heartbeat_ack_msg = {
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "type": "heart",
"type": "heartbeat" "timestamp": get_current_time_str(),
"client_ip": conn.client_ip
} }
await connected_clients[client_ip].websocket.send_json(ack) await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
return True return True
except Exception: except Exception as e:
connected_clients.pop(client_ip, None) connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
return False return False
async def handle_text_msg(client_ip: str, text: str, conn: ClientConnection): async def handle_text_msg(conn: ClientConnection, text: str):
"""处理文本消息(核心:心跳+JSON解析原有逻辑不变"""
try: try:
msg = json.loads(text) msg = json.loads(text)
# 仅处理心跳类型消息 if msg.get("type") == "heart":
if msg.get("type") == "heartbeat":
conn.update_heartbeat() conn.update_heartbeat()
await send_heartbeat_ack(client_ip) await send_heartbeat_ack(conn)
else: else:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}收到文本消息{msg}") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}无效JSON消息") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}无效JSON文本消息")
async def handle_binary_msg(client_ip: str, data: bytes): async def handle_binary_msg(conn: ClientConnection, data: bytes):
"""处理二进制消息(原有逻辑不变,因客户端仅在收到允许信号后发送,队列不会满)"""
if client_ip not in connected_clients:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接不存在,丢弃{len(data)}字节数据")
return
conn = connected_clients[client_ip]
# 检查队列是否已满(理论上不会触发,因客户端按信号发送)
if conn.frame_queue.full():
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列已满,丢弃{len(data)}字节数据")
return
# 队列未满,添加帧到队列
try: try:
conn.frame_queue.put_nowait(data) conn.frame_queue.put_nowait(data)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}已接收{len(data)}字节二进制数据,加入队列") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull: except asyncio.QueueFull:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:队列突然满了,丢弃{len(data)}字节数据") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}队列已满,丢弃当前图像数据")
# WebSocket路由配置
ws_router = APIRouter()
# -------------------------- WebSocket核心端点修改连接初始化逻辑 --------------------------
@ws_router.websocket(WS_ENDPOINT) @ws_router.websocket(WS_ENDPOINT)
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
# 接受连接 + 获取客户端IP
await websocket.accept() await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown" client_ip = websocket.client.host if websocket.client else "unknown_ip"
now = datetime.datetime.now() current_time = get_current_time_str()
print(f"[{now:%H:%M:%S}] 客户端{client_ip}连接成功") print(f"[{current_time}] 客户端{client_ip}WebSocket连接已建立")
is_online_updated = False
consumer_task = None
try: try:
# 处理重复连接(关闭旧连接) # 处理重复连接
if client_ip in connected_clients: if client_ip in connected_clients:
# 取消旧连接的消费者任务 old_conn = connected_clients[client_ip]
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done(): if old_conn.consumer_task and not old_conn.consumer_task.done():
connected_clients[client_ip].consumer_task.cancel() old_conn.consumer_task.cancel()
await connected_clients[client_ip].websocket.close(code=1008, reason="同一IP新连接") await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
connected_clients.pop(client_ip) connected_clients.pop(client_ip)
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:关闭旧连接") print(f"[{current_time}] 客户端{client_ip}关闭旧连接")
# 注册新连接 # 注册新连接
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn connected_clients[client_ip] = new_conn
new_conn.start_consumer()
# 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发)
await new_conn.send_frame_permit()
# 启动帧消费任务 # 标记上线并记录
consumer_task = new_conn.start_consumer() try:
# ---------- 修改:客户端刚连接时,队列空,立即发送「允许发送帧」信号 ---------- await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
await new_conn.send_allow_send_frame() action_data = DeviceActionCreate(client_ip=client_ip, action=1)
print(f"[{now:%H:%M:%S}] 客户端{client_ip}:注册成功,已启动帧消费任务,当前在线{len(connected_clients)}") await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
# 循环接收消息(原有逻辑不变) print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}")
# 消息循环
while True: while True:
data = await websocket.receive() data = await websocket.receive()
if "text" in data: if "text" in data:
await handle_text_msg(client_ip, data["text"], new_conn) await handle_text_msg(new_conn, data["text"])
elif "bytes" in data: elif "bytes" in data:
await handle_binary_msg(client_ip, data["bytes"]) await handle_binary_msg(new_conn, data["bytes"])
# 异常处理(断开/错误)
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:主动断开(代码:{e.code}") print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code}")
except Exception as e: except Exception as e:
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}:连接异常{str(e)[:50]}") print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
finally: finally:
# 清理连接和任务 # 清理资源并标记离线
if client_ip in connected_clients: if client_ip in connected_clients:
# 取消消费者任务 conn = connected_clients[client_ip]
if connected_clients[client_ip].consumer_task and not connected_clients[client_ip].consumer_task.done(): if conn.consumer_task and not conn.consumer_task.done():
connected_clients[client_ip].consumer_task.cancel() conn.consumer_task.cancel()
# 主动/异常断开时标记离线
if is_online_updated:
try:
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None) connected_clients.pop(client_ip, None)
print(f"[{datetime.datetime.now():%H:%M:%S}] 客户端{client_ip}连接已清理,当前在线{len(connected_clients)}") print(f"[{get_current_time_str()}] 客户端{client_ip}资源已清理,在线数:{len(connected_clients)}")