优化代码风格

This commit is contained in:
ZZX9599
2025-09-08 17:34:23 +08:00
parent 9b3d20511a
commit 8ceb92c572
20 changed files with 223 additions and 192 deletions

View File

@ -12,8 +12,4 @@ charset = utf8mb4
[jwt]
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
algorithm = HS256
access_token_expire_minutes = 30
[live]
rtmp_url = rtmp://192.168.110.25:1935/live/
webrtc_url = http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=
access_token_expire_minutes = 30

View File

@ -22,7 +22,7 @@ _task_counter_lock = threading.Lock() # 任务计数锁
# -------------------------- 工具函数 --------------------------
def _get_next_task_id():
"""获取唯一任务ID用于日志追踪"""
"""获取唯一任务ID用于日志追踪"""
global _task_counter
with _task_counter_lock:
_task_counter += 1
@ -65,11 +65,11 @@ def _init_thread_pool():
max_workers=MAX_WORKERS,
thread_name_prefix="DetectionThread"
)
print(f"=== 线程池初始化完成最大线程数: {MAX_WORKERS} ===")
print(f"=== 线程池初始化完成最大线程数: {MAX_WORKERS} ===")
def shutdown():
"""关闭线程池释放资源"""
"""关闭线程池释放资源"""
global _executor
with _executor_lock:
if _executor is not None:
@ -82,7 +82,7 @@ def shutdown():
def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
"""在子线程中执行检测逻辑"""
thread_name = threading.current_thread().name
print(f"任务[{task_id}] 开始执行线程: {thread_name}")
print(f"任务[{task_id}] 开始执行线程: {thread_name}")
try:
# 按照优先级执行检测
@ -98,7 +98,7 @@ def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}")
if flag:
print(f"任务[{task_id}] 完成检测使用检测器: {detector}")
print(f"任务[{task_id}] 完成检测使用检测器: {detector}")
return (True, result, detector, task_id)
# 所有检测器均未检测到结果
@ -116,14 +116,14 @@ def detect(frame: np.ndarray) -> Future:
提交检测任务到线程池
参数:
frame: 待检测图像(ndarray格式cv2.imdecode生成)
frame: 待检测图像(ndarray格式cv2.imdecode生成)
返回:
Future对象通过result()方法获取检测结果
Future对象通过result()方法获取检测结果
"""
# 确保模型已加载
if not _model_loaded:
print("警告: 模型尚未加载将自动加载")
print("警告: 模型尚未加载将自动加载")
load_model()
# 生成任务ID
@ -131,6 +131,6 @@ def detect(frame: np.ndarray) -> Future:
# 提交任务到线程池
future = _executor.submit(_detect_in_thread, frame, task_id)
print(f"任务[{task_id}] 已提交到线程池")
print(f"任务[{task_id}]: 已提交到线程池")
return future

View File

@ -16,7 +16,7 @@ try:
pynvml.nvmlInit()
_nvml_available = True
except ImportError:
print("警告: pynvml库未安装无法检测GPU状态将默认使用0号GPU")
print("警告: pynvml库未安装无法检测GPU状态将默认使用0号GPU")
_nvml_available = False
# 全局变量
@ -58,7 +58,7 @@ def check_gpu_availability(gpu_id, threshold=0.7):
def select_best_gpu(preferred_gpus=[0, 1]):
"""选择最佳可用GPU严格按照首选列表顺序检查优先使用0号GPU"""
"""选择最佳可用GPU严格按照首选列表顺序检查优先使用0号GPU"""
# 首先检查首选GPU列表
for gpu_id in preferred_gpus:
try:
@ -68,17 +68,17 @@ def select_best_gpu(preferred_gpus=[0, 1]):
# 检查GPU是否可用
if check_gpu_availability(gpu_id):
print(f"GPU {gpu_id} 可用将使用该GPU")
print(f"GPU {gpu_id} 可用将使用该GPU")
return gpu_id
else:
if gpu_id == 0:
print(f"GPU 0 内存使用率过高(繁忙)尝试切换到其他GPU")
print(f"GPU 0 内存使用率过高(繁忙)尝试切换到其他GPU")
except Exception as e:
print(f"GPU {gpu_id} 不存在或无法访问: {e}")
continue
# 如果所有首选GPU都不可用返回-1表示使用CPU
print("所有指定的GPU都不可用将使用CPU进行计算")
# 如果所有首选GPU都不可用返回-1表示使用CPU
print("所有指定的GPU都不可用将使用CPU进行计算")
return -1
@ -122,12 +122,12 @@ def _release_engine():
def _monitor_thread():
"""监控线程检查并释放超时未使用的资源"""
"""监控线程检查并释放超时未使用的资源"""
global _ref_count, _last_used_time, _face_app
while True:
time.sleep(5) # 每5秒检查一次
with _lock:
# 只有当引擎存在、没有引用且超时才释放
# 只有当引擎存在、没有引用且超时才释放
if _face_app and _ref_count == 0 and not _is_releasing:
elapsed = time.time() - _last_used_time
if elapsed > _release_timeout:
@ -136,7 +136,7 @@ def _monitor_thread():
def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
"""加载人脸识别模型及已知人脸特征库默认优先使用0号GPU"""
"""加载人脸识别模型及已知人脸特征库默认优先使用0号GPU"""
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
# 确保监控线程只启动一次
@ -144,11 +144,11 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
threading.Thread(target=_monitor_thread, daemon=True, name="FaceMonitor").start()
print("Face monitor thread started")
# 如果正在释放中等待释放完成
# 如果正在释放中等待释放完成
while _is_releasing:
time.sleep(0.1)
# 如果已经初始化直接返回
# 如果已经初始化直接返回
if _face_app:
return True
@ -158,7 +158,7 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
print("正在初始化InsightFace人脸识别引擎...")
_face_app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
# 选择合适的GPU默认优先使用0号
# 选择合适的GPU默认优先使用0号
ctx_id = 0
if prefer_gpu:
ctx_id = select_best_gpu(preferred_gpus)
@ -166,9 +166,9 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
_used_gpu_id = ctx_id if _using_gpu else -1
if _using_gpu:
print(f"成功初始化使用GPU {ctx_id} 进行计算")
print(f"成功初始化使用GPU {ctx_id} 进行计算")
else:
print("成功初始化使用CPU进行计算")
print("成功初始化使用CPU进行计算")
# 准备模型
_face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
@ -188,10 +188,10 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
for person_name, eigenvalue_data in face_data.items():
# 处理特征值数据 - 兼容数组和字符串两种格式
if isinstance(eigenvalue_data, np.ndarray):
# 如果已经是numpy数组直接使用
# 如果已经是numpy数组直接使用
eigenvalue = eigenvalue_data.astype(np.float32)
elif isinstance(eigenvalue_data, str):
# 清理字符串移除方括号、换行符和多余空格
# 清理字符串: 移除方括号、换行符和多余空格
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
# 按空格或逗号分割(处理可能的不同分隔符)
values = [v for v in cleaned.split() if v]
@ -217,7 +217,7 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
def detect(frame, threshold=0.4):
"""检测并识别人脸返回结果元组(是否匹配到已知人脸, 结果字符串)"""
"""检测并识别人脸返回结果元组(是否匹配到已知人脸, 结果字符串)"""
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
global _ref_count, _last_used_time
@ -248,7 +248,7 @@ def detect(frame, threshold=0.4):
return (False, "人脸识别引擎不可用或未初始化")
try:
# 如果使用GPU确保输入帧在处理前是连续的数组
# 如果使用GPU确保输入帧在处理前是连续的数组
if _using_gpu and not frame.flags.contiguous:
frame = np.ascontiguousarray(frame)
@ -285,7 +285,7 @@ def detect(frame, threshold=0.4):
# 判断匹配结果
is_match = max_sim >= threshold
if is_match:
has_matched = True # 只要有一个匹配成功就标记为True
has_matched = True # 只要有一个匹配成功就标记为True
bbox = face.bbox
result_parts.append(
@ -298,12 +298,12 @@ def detect(frame, threshold=0.4):
else:
result_str = "; ".join(result_parts)
# 减少引用计数确保线程安全
# 减少引用计数确保线程安全
with _lock:
_ref_count = max(0, _ref_count - 1)
# 持续使用时更新最后使用时间
if _ref_count > 0:
_last_used_time = time.time()
# 第一个返回值为是否匹配到已知人脸
# 第一个返回值为: 是否匹配到已知人脸
return (has_matched, result_str)

View File

@ -61,12 +61,12 @@ def _release_engine():
def _monitor_thread():
"""监控线程优化检查逻辑"""
"""监控线程优化检查逻辑"""
global _ref_count, _last_used_time, _ocr_engine
while True:
time.sleep(5) # 每5秒检查一次
with _lock:
# 只有当引擎存在、没有引用且超时才释放
# 只有当引擎存在、没有引用且超时才释放
if _ocr_engine and _ref_count == 0 and not _is_releasing:
elapsed = time.time() - _last_used_time
if elapsed > _release_timeout:
@ -100,7 +100,7 @@ def load_model():
def detect(frame):
"""OCR检测优化引用计数管理"""
"""OCR检测优化引用计数管理"""
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time
# 验证前置条件
@ -178,7 +178,7 @@ def detect(frame):
return (False, f"检测错误: {str(e)}")
finally:
# 减少引用计数确保线程安全
# 减少引用计数确保线程安全
with _lock:
_ref_count = max(0, _ref_count - 1)
# 持续使用时更新最后使用时间

View File

@ -1,6 +1,5 @@
import os
import cv2
from ultralytics import YOLO
# 全局变量
@ -24,7 +23,7 @@ def load_model():
def detect(frame, conf_threshold=0.2):
"""YOLO目标检测返回(是否识别到, 结果字符串)"""
"""YOLO目标检测返回(是否识别到, 结果字符串)"""
global _yolo_model
if not _yolo_model or frame is None:

View File

@ -14,4 +14,3 @@ config.read(config_path, encoding="utf-8")
SERVER_CONFIG = config["server"]
MYSQL_CONFIG = config["mysql"]
JWT_CONFIG = config["jwt"]
LIVE_CONFIG = config["live"]

View File

@ -47,7 +47,7 @@ if __name__ == "__main__":
YOLO_MODEL_PATH = r"/core/models\best.pt"
OCR_CONFIG_PATH = r"/core/config\config.yaml"
# 初始化项目默认端口设为8000避免初始化失败时port未定义
# 初始化项目默认端口设为8000避免初始化失败时port未定义
port = int(SERVER_CONFIG.get("port", 8000))
# 启动 UVicorn 服务

View File

@ -9,8 +9,6 @@ from passlib.context import CryptContext
from ds.config import JWT_CONFIG
from ds.db import db
# 移除这里的 from service.user_service import UserResponse 导入
# ------------------------------
# 密码加密配置
# ------------------------------
@ -23,7 +21,7 @@ SECRET_KEY = JWT_CONFIG["secret_key"]
ALGORITHM = JWT_CONFIG["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 依赖(从请求头获取 Token、格式Bearer <token>
# OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
@ -63,7 +61,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
# 延迟导入打破循环依赖
# 延迟导入打破循环依赖
from schema.user_schema import UserResponse # 在这里导入
# 认证失败异常
@ -101,4 +99,4 @@ def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型
except Exception as e:
raise credentials_exception from e
finally:
db.close_connection(conn, cursor)
db.close_connection(conn, cursor)

View File

@ -8,7 +8,7 @@ from schema.response_schema import APIResponse
async def global_exception_handler(request: Request, exc: Exception):
"""全局异常处理器所有未捕获的异常都会在这里统一处理"""
"""全局异常处理器: 所有未捕获的异常都会在这里统一处理"""
# 1. 请求参数验证错误Pydantic 校验失败)
if isinstance(exc, RequestValidationError):
error_details = []
@ -18,7 +18,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_400_BAD_REQUEST,
content=APIResponse(
code=400,
message=f"请求参数错误{'; '.join(error_details)}",
message=f"请求参数错误: {'; '.join(error_details)}",
data=None
).model_dump()
)
@ -52,7 +52,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse(
code=500,
message=f"数据库错误{str(exc)}",
message=f"数据库错误: {str(exc)}",
data=None
).model_dump()
)
@ -62,7 +62,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse(
code=500,
message=f"服务器内部错误{str(exc)}",
message=f"服务器内部错误: {str(exc)}",
data=None
).model_dump()
)

View File

@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
# ------------------------------
# 请求模型(新增记录用,极简)
# 请求模型
# ------------------------------
class DeviceActionCreate(BaseModel):
"""设备操作记录创建模型0=离线1=上线)"""
"""设备操作记录创建模型0=离线1=上线)"""
client_ip: str = Field(..., description="客户端IP")
action: int = Field(..., ge=0, le=1, description="操作状态0=离线1=上线)")
action: int = Field(..., ge=0, le=1, description="操作状态0=离线1=上线)")
# ------------------------------
@ -19,7 +19,7 @@ class DeviceActionResponse(BaseModel):
"""设备操作记录响应模型(与自增表对齐)"""
id: int = Field(..., description="自增主键ID")
client_ip: Optional[str] = Field(None, description="客户端IP")
action: Optional[int] = Field(None, description="操作状态0=离线1=上线)")
action: Optional[int] = Field(None, description="操作状态0=离线1=上线)")
created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间")

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
# 请求模型(前端传参校验)
# ------------------------------
class FaceCreateRequest(BaseModel):
"""创建人脸记录请求模型无需ID由数据库自增)"""
name: str = Field(None, max_length=255, description="名称(可选最长255字符")
"""创建人脸记录请求模型无需ID由数据库自增)"""
name: str = Field(None, max_length=255, description="名称(可选最长255字符")
class FaceUpdateRequest(BaseModel):
@ -20,7 +20,7 @@ class FaceUpdateRequest(BaseModel):
# 响应模型(后端返回数据)
# ------------------------------
class FaceResponse(BaseModel):
"""人脸记录响应模型仍包含ID由数据库生成后返回)"""
"""人脸记录响应模型仍包含ID由数据库生成后返回)"""
id: int = Field(..., description="主键ID数据库自增")
name: str = Field(None, description="名称")
eigenvalue: str | None = Field(None, description="特征(可为空)")

View File

@ -5,9 +5,9 @@ from pydantic import BaseModel, Field
class APIResponse(BaseModel):
"""统一 API 响应模型(所有接口必返此格式)"""
code: int = Field(..., description="状态码200=成功、4xx=客户端错误、5xx=服务端错误")
message: str = Field(..., description="响应信息成功/错误描述")
data: Optional[Any] = Field(None, description="响应数据成功时返回、错误时为 None")
code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
message: str = Field(..., description="响应信息: 成功/错误描述")
data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
# Pydantic V2 配置(支持从 ORM 对象转换)
model_config = {"from_attributes": True}

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
# ------------------------------
class SensitiveCreateRequest(BaseModel):
"""创建敏感信息记录请求模型"""
# 移除了id字段由数据库自动生成
# 移除了id字段由数据库自动生成
name: str = Field(None, max_length=255, description="名称")

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, Query
from fastapi import APIRouter, Query, Path
from mysql.connector import Error as MySQLError
from ds.db import db
@ -9,7 +9,6 @@ from schema.device_action_schema import (
)
from schema.response_schema import APIResponse
# 路由配置
router = APIRouter(
prefix="/device/actions",
@ -18,11 +17,11 @@ router = APIRouter(
# ------------------------------
# 内部方法新增设备操作记录适配id自增
# 内部方法: 新增设备操作记录适配id自增
# ------------------------------
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
"""
新增设备操作记录(内部方法非接口)
新增设备操作记录(内部方法非接口)
:param action_data: 含client_ip和action0/1
:return: 新增的完整记录
"""
@ -32,7 +31,7 @@ def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入SQLid自增依赖数据库自动生成)
# 插入SQLid自增依赖数据库自动生成)
insert_query = """
INSERT INTO device_action
(client_ip, action, created_at, updated_at)
@ -54,20 +53,20 @@ def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"新增记录失败{str(e)}") from e
raise Exception(f"新增记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 接口分页查询操作记录列表(仅返回 total + device_actions
# 接口: 分页查询操作记录列表(仅返回 total + device_actions
# ------------------------------
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
async def get_device_action_list(
page: int = Query(1, ge=1, description="页码默认1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100"),
page: int = Query(1, ge=1, description="页码默认1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100"),
client_ip: str = Query(None, description="按客户端IP筛选"),
action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线1=上线)")
action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线1=上线)")
):
conn = None
cursor = None
@ -75,7 +74,7 @@ async def get_device_action_list(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 构建筛选条件(参数化查询避免注入)
# 1. 构建筛选条件(参数化查询避免注入)
where_clause = []
params = []
if client_ip:
@ -92,13 +91,13 @@ async def get_device_action_list(
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 3. 分页查询记录(按创建时间倒序确保最新记录在前)
# 3. 分页查询记录(按创建时间倒序确保最新记录在前)
offset = (page - 1) * page_size
list_sql = "SELECT * FROM device_action"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) # 追加分页参数page/page_size仅用于查询不返回)
params.extend([page_size, offset]) # 追加分页参数page/page_size仅用于查询不返回)
cursor.execute(list_sql, params)
action_list = cursor.fetchall()
@ -114,6 +113,46 @@ async def get_device_action_list(
)
except MySQLError as e:
raise Exception(f"查询记录失败{str(e)}") from e
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
db.close_connection(conn, cursor)
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
async def get_device_actions_by_ip(
client_ip: str = Path(..., description="客户端IP地址")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 查询总记录数
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
cursor.execute(count_sql, (client_ip,))
total = cursor.fetchone()["total"]
# 2. 查询该IP的所有记录按创建时间倒序
list_sql = """
SELECT * FROM device_action
WHERE client_ip = %s
ORDER BY created_at DESC
"""
cursor.execute(list_sql, (client_ip,))
action_list = cursor.fetchall()
# 3. 返回结果
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -37,7 +37,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
if not cursor.fetchone():
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 报警次数加1并更新时间戳
# 报警次数加1并更新时间戳
update_query = """
UPDATE devices
SET alarm_count = alarm_count + 1,
@ -51,7 +51,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新报警次数失败{str(e)}") from e
raise Exception(f"更新报警次数失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -99,7 +99,7 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新设备在线状态失败{str(e)}") from e
raise Exception(f"更新设备在线状态失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -124,7 +124,7 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
# 返回信息
return APIResponse(
code=200,
message=f"设备IP {device_data.ip} 已存在返回已有设备信息",
message=f"设备IP {device_data.ip} 已存在返回已有设备信息",
data=DeviceResponse(** existing_device)
)
@ -173,9 +173,9 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"创建设备失败{str(e)}") from e
raise Exception(f"创建设备失败: {str(e)}") from e
except json.JSONDecodeError as e:
raise Exception(f"设备详细信息JSON序列化失败{str(e)}") from e
raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e
except Exception as e:
if conn:
conn.rollback()
@ -185,8 +185,8 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
async def get_device_list(
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
device_type: str = Query(None, description="按设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
):
@ -233,6 +233,6 @@ async def get_device_list(
)
except MySQLError as e:
raise Exception(f"获取设备列表失败{str(e)}") from e
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -18,27 +18,27 @@ router = APIRouter(
# ------------------------------
# 1. 创建人脸记录(核心修正ID 数据库自增前端无需传)
# 1. 创建人脸记录(核心修正: ID 数据库自增前端无需传)
# ------------------------------
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件ID自增")
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件ID自增")
async def create_face(
# 前端仅需传name可选Form格式、file必传文件)
# 前端仅需传: name可选Form格式、file必传文件)
name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传暂不处理内容)")
file: UploadFile = File(..., description="人脸文件(必传暂不处理内容)")
):
"""
创建人脸记录
创建人脸记录:
- 需登录认证
- 前端传参multipart/form-data 表单name 可选file 必传)
- ID 由数据库自动生成无需前端传入
- 暂不处理文件内容eigenvalue 设为 None
- 前端传参: multipart/form-data 表单name 可选file 必传)
- ID 由数据库自动生成无需前端传入
- 暂不处理文件内容eigenvalue 设为 None
"""
# 调用你的方法
conn = None
cursor = None
try:
# 1. 用模型校验 name仅校验长度无需ID
# 1. 用模型校验 name仅校验长度无需ID
face_create = FaceCreateRequest(name=name)
conn = db.get_connection()
@ -57,9 +57,9 @@ async def create_face(
)
# 打印数组长度
print(f"文件大小{len(file_content)} 字节")
print(f"文件大小: {len(file_content)} 字节")
# 2. 插入数据库无需传 ID自增只传 name 和 eigenvalueNone
# 2. 插入数据库: 无需传 ID自增只传 name 和 eigenvalueNone
insert_query = """
INSERT INTO face (name, eigenvalue)
VALUES (%s, %s)
@ -67,7 +67,7 @@ async def create_face(
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
conn.commit()
# 3. 获取数据库自动生成的 ID关键用 LAST_INSERT_ID() 查刚插入的记录)
# 3. 获取数据库自动生成的 ID关键: 用 LAST_INSERT_ID() 查刚插入的记录)
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
cursor.execute(select_new_query)
created_face = cursor.fetchone()
@ -75,12 +75,12 @@ async def create_face(
if not created_face:
raise HTTPException(
status_code=500,
detail="创建人脸记录成功但无法获取新创建的记录"
detail="创建人脸记录成功但无法获取新创建的记录"
)
return APIResponse(
code=201,
message=f"人脸记录创建成功ID{created_face['id']}文件名{file.filename}",
message=f"人脸记录创建成功ID: {created_face['id']}文件名: {file.filename}",
data=FaceResponse(** created_face)
)
except MySQLError as e:
@ -89,13 +89,13 @@ async def create_face(
# 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"创建人脸记录失败{str(e)}"
detail=f"创建人脸记录失败: {str(e)}"
) from e
except Exception as e:
# 捕获其他可能的异常
raise HTTPException(
status_code=500,
detail=f"服务器错误{str(e)}"
detail=f"服务器错误: {str(e)}"
) from e
finally:
await file.close() # 关闭文件流
@ -113,11 +113,11 @@ async def create_face(
eigenvalue = str(eigenvalue)
# ------------------------------
# 2. 获取单个人脸记录(不变用自增ID查询
# 2. 获取单个人脸记录(不变用自增ID查询
# ------------------------------
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
async def get_face(
face_id: int, # 这里的 ID 是数据库自增的前端从创建响应中获取
face_id: int, # 这里的 ID 是数据库自增的前端从创建响应中获取
current_user: UserResponse = Depends(get_current_user)
):
conn = None
@ -145,7 +145,7 @@ async def get_face(
# 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"查询人脸记录失败{str(e)}"
detail=f"查询人脸记录失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
@ -176,14 +176,14 @@ async def get_all_faces(
except MySQLError as e:
raise HTTPException(
status_code=500,
detail=f"查询所有人脸记录失败{str(e)}"
detail=f"查询所有人脸记录失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 4. 更新人脸记录(不变用自增ID更新
# 4. 更新人脸记录(不变用自增ID更新
# ------------------------------
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
async def update_face(
@ -240,14 +240,14 @@ async def update_face(
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"更新人脸记录失败{str(e)}"
detail=f"更新人脸记录失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 5. 删除人脸记录(不变用自增ID删除
# 5. 删除人脸记录(不变用自增ID删除
# ------------------------------
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
async def delete_face(
@ -283,7 +283,7 @@ async def delete_face(
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"删除人脸记录失败{str(e)}"
detail=f"删除人脸记录失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
@ -291,10 +291,10 @@ async def delete_face(
def get_all_face_name_with_eigenvalue() -> dict:
"""
获取所有人脸的名称及其对应的特征值组成字典返回
获取所有人脸的名称及其对应的特征值组成字典返回
key: 人脸名称name
value: 人脸特征值eigenvalue若名称重复则返回平均特征值
过滤掉name为None的记录避免字典key为None的情况
value: 人脸特征值eigenvalue若名称重复则返回平均特征值
: 过滤掉name为None的记录避免字典key为None的情况
"""
conn = None
cursor = None
@ -303,27 +303,27 @@ def get_all_face_name_with_eigenvalue() -> dict:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 2. 执行SQL查询只获取name非空的记录减少数据传输
# 2. 执行SQL查询: 只获取name非空的记录减少数据传输
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query)
faces = cursor.fetchall() # 返回结果列表套字典如 [{"name":"张三","eigenvalue":...}, ...]
faces = cursor.fetchall() # 返回结果: 列表套字典如 [{"name":"张三","eigenvalue":...}, ...]
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
name_to_eigenvalues = {}
for face in faces:
name = face["name"]
eigenvalue = face["eigenvalue"]
# 若名称已存在追加特征值;否则新建列表存储
# 若名称已存在追加特征值;否则新建列表存储
if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue)
else:
name_to_eigenvalues[name] = [eigenvalue]
# 4. 构建最终字典重复名称取平均唯一名称直接取特征值
# 4. 构建最终字典: 重复名称取平均唯一名称直接取特征值
face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items():
# 处理特征值多个则求平均单个则直接使用
# 处理特征值: 多个则求平均单个则直接使用
if len(eigenvalues) > 1:
# 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入
face_dict[name] = get_average_feature(eigenvalues)
@ -334,8 +334,8 @@ def get_all_face_name_with_eigenvalue() -> dict:
return face_dict
except MySQLError as e:
# 捕获数据库异常添加上下文信息后重新抛出(便于定位问题)
raise Exception(f"获取人脸名称与特征值失败{str(e)}") from e
# 捕获数据库异常添加上下文信息后重新抛出(便于定位问题)
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
finally:
# 5. 无论是否异常均释放数据库连接和游标(避免资源泄漏)
# 5. 无论是否异常均释放数据库连接和游标(避免资源泄漏)
db.close_connection(conn, cursor)

View File

@ -21,7 +21,7 @@ router = APIRouter(
async def create_sensitive(
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
"""
创建敏感信息记录
创建敏感信息记录:
- 需登录认证
- 插入新的敏感信息记录到数据库ID由数据库自动生成
- 返回创建成功信息
@ -32,7 +32,7 @@ async def create_sensitive(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入新敏感信息记录到数据库不包含ID由数据库自动生成)
# 插入新敏感信息记录到数据库不包含ID由数据库自动生成)
insert_query = """
INSERT INTO sensitives (name)
VALUES (%s)
@ -56,7 +56,7 @@ async def create_sensitive(
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"创建敏感信息记录失败{str(e)}") from e
raise Exception(f"创建敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -71,7 +71,7 @@ async def get_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
"""
获取单个敏感信息记录
获取单个敏感信息记录:
- 需登录认证
- 根据ID查询敏感信息记录
- 返回查询到的敏感信息
@ -98,7 +98,7 @@ async def get_sensitive(
data=SensitiveResponse(**sensitive)
)
except MySQLError as e:
raise Exception(f"查询敏感信息记录失败{str(e)}") from e
raise Exception(f"查询敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -109,7 +109,7 @@ async def get_sensitive(
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
async def get_all_sensitives():
"""
获取所有敏感信息记录
获取所有敏感信息记录:
- 需登录认证
- 查询所有敏感信息记录(不需要分页)
- 返回所有敏感信息列表
@ -130,7 +130,7 @@ async def get_all_sensitives():
data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
)
except MySQLError as e:
raise Exception(f"查询所有敏感信息记录失败{str(e)}") from e
raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -145,7 +145,7 @@ async def update_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
"""
更新敏感信息记录
更新敏感信息记录:
- 需登录认证
- 根据ID更新敏感信息记录
- 返回更新后的敏感信息
@ -203,7 +203,7 @@ async def update_sensitive(
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"更新敏感信息记录失败{str(e)}") from e
raise Exception(f"更新敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -217,7 +217,7 @@ async def delete_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
"""
删除敏感信息记录
删除敏感信息记录:
- 需登录认证
- 根据ID删除敏感信息记录
- 返回删除成功信息
@ -251,14 +251,14 @@ async def delete_sensitive(
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"删除敏感信息记录失败{str(e)}") from e
raise Exception(f"删除敏感信息记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def get_all_sensitive_words() -> list[str]:
"""
获取所有敏感词返回字符串数组
获取所有敏感词返回字符串数组
返回:
list[str]: 包含所有敏感词的数组
@ -273,7 +273,7 @@ def get_all_sensitive_words() -> list[str]:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 执行查询只获取敏感词字段
# 执行查询只获取敏感词字段
query = "SELECT name FROM sensitives ORDER BY id"
cursor.execute(query)
sensitive_records = cursor.fetchall()
@ -283,7 +283,7 @@ def get_all_sensitive_words() -> list[str]:
except MySQLError as e:
# 数据库错误处理
raise MySQLError(f"查询敏感词失败{str(e)}") from e
raise MySQLError(f"查询敏感词失败: {str(e)}") from e
finally:
# 确保资源正确释放
db.close_connection(conn, cursor)

View File

@ -27,7 +27,7 @@ router = APIRouter(
@router.post("/register", response_model=APIResponse, summary="用户注册")
async def user_register(request: UserRegisterRequest):
"""
用户注册
用户注册:
- 校验用户名是否已存在
- 加密密码后插入数据库
- 返回注册成功信息
@ -67,7 +67,7 @@ async def user_register(request: UserRegisterRequest):
)
except MySQLError as e:
conn.rollback() # 数据库错误时回滚事务
raise Exception(f"注册失败{str(e)}") from e
raise Exception(f"注册失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -78,7 +78,7 @@ async def user_register(request: UserRegisterRequest):
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token")
async def user_login(request: UserLoginRequest):
"""
用户登录
用户登录:
- 校验用户名是否存在
- 校验密码是否正确
- 生成 JWT Token 并返回
@ -89,7 +89,7 @@ async def user_login(request: UserLoginRequest):
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 修复SQL查询添加 created_at 和 updated_at 字段
# 修复: SQL查询添加 created_at 和 updated_at 字段
query = """
SELECT id, username, password, created_at, updated_at
FROM users
@ -129,7 +129,7 @@ async def user_login(request: UserLoginRequest):
}
)
except MySQLError as e:
raise Exception(f"登录失败{str(e)}") from e
raise Exception(f"登录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@ -142,8 +142,8 @@ async def get_current_user_info(
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
):
"""
获取当前登录用户信息
- 需在请求头携带 Token格式Bearer <token>
获取当前登录用户信息:
- 需在请求头携带 Token格式: Bearer <token>
- 认证通过后返回用户信息
"""
return APIResponse(

View File

@ -27,7 +27,7 @@ def init_insightface():
def add_binary_data(binary_data):
"""
接收单张图片的二进制数据提取特征并保存
接收单张图片的二进制数据提取特征并保存
参数:
binary_data: 图片的二进制数据bytes类型
@ -39,11 +39,11 @@ def add_binary_data(binary_data):
global _insightface_app, _feature_list
if not _insightface_app:
print("引擎未初始化无法处理")
print("引擎未初始化无法处理")
return False, None
try:
# 直接处理二进制数据转换为图像格式
# 直接处理二进制数据: 转换为图像格式
img = Image.open(BytesIO(binary_data))
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@ -70,15 +70,15 @@ def get_average_feature(features=None):
计算多个特征向量的平均值
参数:
features: 可选特征值列表。如果未提供则使用全局存储的_feature_list
features: 可选特征值列表。如果未提供则使用全局存储的_feature_list
每个元素可以是字符串格式或numpy数组
返回:
单一平均特征向量的numpy数组若无可计算数据则返回None
单一平均特征向量的numpy数组若无可计算数据则返回None
"""
global _feature_list
# 如果未提供features参数则使用全局特征列表
# 如果未提供features参数则使用全局特征列表
if features is None:
features = _feature_list
@ -105,7 +105,7 @@ def get_average_feature(features=None):
processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
else:
print(f"跳过第 {i + 1} 个特征值不是一维数组")
print(f"跳过第 {i + 1} 个特征值不是一维数组")
except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}")
@ -118,12 +118,12 @@ def get_average_feature(features=None):
# 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1:
print(f"特征值维度不一致无法计算平均值。检测到的维度: {dims}")
print(f"特征值维度不一致无法计算平均值。检测到的维度: {dims}")
return None
# 计算平均值
avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量维度: {avg_feature.shape[0]}")
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量维度: {avg_feature.shape[0]}")
return avg_feature

View File

@ -21,7 +21,7 @@ WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# 工具函数获取格式化时间字符串(统一时间戳格式)
# 工具函数: 获取格式化时间字符串(统一时间戳格式)
def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -65,32 +65,32 @@ class ClientConnection:
"client_ip": self.client_ip
}
await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}已发送帧发送许可信号(取帧后立即通知)")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号(取帧后立即通知)")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}帧许可信号发送失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}")
async def consume_frames(self) -> None:
"""消费队列中的帧并处理(核心调整取帧后立即发许可再处理帧)"""
"""消费队列中的帧并处理(核心调整: 取帧后立即发许可再处理帧)"""
try:
while True:
# 1. 从队列取出帧(阻塞直到有帧可用)
frame_data = await self.frame_queue.get()
# -------------------------- 核心修改取出帧后立即发送下一帧许可 --------------------------
await self.send_frame_permit() # 取帧即通知客户端发下一帧无需等处理完成
# -------------------------- 核心修改: 取出帧后立即发送下一帧许可 --------------------------
await self.send_frame_permit() # 取帧即通知客户端发下一帧无需等处理完成
# -----------------------------------------------------------------------------------------
try:
# 2. 处理取出的帧(即使处理慢客户端也已收到许可可提前准备下一帧)
# 2. 处理取出的帧(即使处理慢客户端也已收到许可可提前准备下一帧)
await self.process_frame(frame_data)
finally:
# 3. 标记帧任务完成(无论处理成功/失败都需清理队列)
# 3. 标记帧任务完成(无论处理成功/失败都需清理队列)
self.frame_queue.task_done()
except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}帧消费任务已取消")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}帧消费逻辑错误 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)"""
@ -98,31 +98,31 @@ class ClientConnection:
nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}无法解析图像数据")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像数据")
return
# 确保图像保存目录存在
os.makedirs('images', exist_ok=True)
# 保存图像按IP+时间戳命名避免冲突)
# 保存图像按IP+时间戳命名避免冲突)
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
try:
cv2.imwrite(filename, img)
print(f"[{get_current_time_str()}] 图像已保存至{filename}")
print(f"[{get_current_time_str()}] 图像已保存至: {filename}")
has_violation, data, type = detect(img)
print(has_violation)
print(type)
print(data)
if has_violation:
print(
f"[{get_current_time_str()}] 客户端{self.client_ip}检测到违规 - 类型: {type}, 详情: {data}")
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - 类型: {type}, 详情: {data}")
# 调用违规次数加一方法
try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}违规次数已+1")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}违规次数更新失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 发送「危险通知」
danger_msg = {
@ -132,9 +132,9 @@ class ClientConnection:
}
await self.websocket.send_json(danger_msg)
else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}未检测到违规")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}图像处理错误 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
# 全局状态管理
@ -149,7 +149,7 @@ async def heartbeat_checker():
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips:
print(f"[{current_time}] 心跳检查{len(timeout_ips)}个客户端超时IP{timeout_ips}")
print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时IP: {timeout_ips}")
for ip in timeout_ips:
try:
conn = connected_clients[ip]
@ -162,13 +162,13 @@ async def heartbeat_checker():
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}已标记为离线并记录操作")
print(f"[{current_time}] 客户端{ip}: 已标记为离线并记录操作")
except Exception as e:
print(f"[{current_time}] 客户端{ip}离线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{ip}: 离线状态更新失败 - {str(e)}")
finally:
connected_clients.pop(ip, None)
else:
print(f"[{current_time}] 心跳检查{len(connected_clients)}个客户端在线")
print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL)
@ -178,7 +178,7 @@ async def heartbeat_checker():
async def lifespan(app: FastAPI):
global heartbeat_task
heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 全局心跳检查任务启动任务ID{id(heartbeat_task)}")
print(f"[{get_current_time_str()}] 全局心跳检查任务启动任务ID: {id(heartbeat_task)}")
yield
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
@ -198,11 +198,11 @@ async def send_heartbeat_ack(conn: ClientConnection):
"client_ip": conn.client_ip
}
await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}已发送心跳确认")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送心跳确认")
return True
except Exception as e:
connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}心跳确认发送失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认发送失败 - {str(e)}")
return False
@ -213,17 +213,17 @@ async def handle_text_msg(conn: ClientConnection, text: str):
conn.update_heartbeat()
await send_heartbeat_ack(conn)
else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}未知文本消息类型({msg.get('type')}")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}无效JSON文本消息")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON文本消息")
async def handle_binary_msg(conn: ClientConnection, data: bytes):
try:
conn.frame_queue.put_nowait(data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}图像数据({len(data)}字节)已加入队列")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}帧队列已满丢弃当前图像数据")
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满丢弃当前图像数据")
# WebSocket路由配置
@ -237,7 +237,7 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}WebSocket连接已建立")
print(f"[{current_time}] 客户端{client_ip}: WebSocket连接已建立")
is_online_updated = False
@ -249,13 +249,13 @@ async def websocket_endpoint(websocket: WebSocket):
old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}已关闭旧连接")
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
# 注册新连接
new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn
new_conn.start_consumer()
# 初始许可连接建立后立即发一次让客户端知道可发第一帧(后续靠取帧后自动发)
# 初始许可: 连接建立后立即发一次让客户端知道可发第一帧(后续靠取帧后自动发)
await new_conn.send_frame_permit()
# 标记上线并记录
@ -263,12 +263,12 @@ async def websocket_endpoint(websocket: WebSocket):
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}已标记为在线并记录操作")
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线并记录操作")
is_online_updated = True
except Exception as e:
print(f"[{current_time}] 客户端{client_ip}上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}新连接注册成功在线数{len(connected_clients)}")
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功在线数: {len(connected_clients)}")
# 消息循环
while True:
@ -279,9 +279,9 @@ async def websocket_endpoint(websocket: WebSocket):
await handle_binary_msg(new_conn, data["bytes"])
except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}主动断开连接(代码{e.code}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}连接异常 - {str(e)[:50]}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally:
# 清理资源并标记离线
if client_ip in connected_clients:
@ -295,9 +295,9 @@ async def websocket_endpoint(websocket: WebSocket):
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}断开后已标记为离线")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后已标记为离线")
except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}断开后离线更新失败 - {str(e)}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}资源已清理在线数{len(connected_clients)}")
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理在线数: {len(connected_clients)}")