优化代码风格

This commit is contained in:
ZZX9599
2025-09-08 17:34:23 +08:00
parent 9b3d20511a
commit 8ceb92c572
20 changed files with 223 additions and 192 deletions

View File

@ -12,8 +12,4 @@ charset = utf8mb4
[jwt] [jwt]
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
algorithm = HS256 algorithm = HS256
access_token_expire_minutes = 30 access_token_expire_minutes = 30
[live]
rtmp_url = rtmp://192.168.110.25:1935/live/
webrtc_url = http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=

View File

@ -22,7 +22,7 @@ _task_counter_lock = threading.Lock() # 任务计数锁
# -------------------------- 工具函数 -------------------------- # -------------------------- 工具函数 --------------------------
def _get_next_task_id(): def _get_next_task_id():
"""获取唯一任务ID用于日志追踪""" """获取唯一任务ID用于日志追踪"""
global _task_counter global _task_counter
with _task_counter_lock: with _task_counter_lock:
_task_counter += 1 _task_counter += 1
@ -65,11 +65,11 @@ def _init_thread_pool():
max_workers=MAX_WORKERS, max_workers=MAX_WORKERS,
thread_name_prefix="DetectionThread" thread_name_prefix="DetectionThread"
) )
print(f"=== 线程池初始化完成最大线程数: {MAX_WORKERS} ===") print(f"=== 线程池初始化完成最大线程数: {MAX_WORKERS} ===")
def shutdown(): def shutdown():
"""关闭线程池释放资源""" """关闭线程池释放资源"""
global _executor global _executor
with _executor_lock: with _executor_lock:
if _executor is not None: if _executor is not None:
@ -82,7 +82,7 @@ def shutdown():
def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple: def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
"""在子线程中执行检测逻辑""" """在子线程中执行检测逻辑"""
thread_name = threading.current_thread().name thread_name = threading.current_thread().name
print(f"任务[{task_id}] 开始执行线程: {thread_name}") print(f"任务[{task_id}] 开始执行线程: {thread_name}")
try: try:
# 按照优先级执行检测 # 按照优先级执行检测
@ -98,7 +98,7 @@ def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}") print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}")
if flag: if flag:
print(f"任务[{task_id}] 完成检测使用检测器: {detector}") print(f"任务[{task_id}] 完成检测使用检测器: {detector}")
return (True, result, detector, task_id) return (True, result, detector, task_id)
# 所有检测器均未检测到结果 # 所有检测器均未检测到结果
@ -116,14 +116,14 @@ def detect(frame: np.ndarray) -> Future:
提交检测任务到线程池 提交检测任务到线程池
参数: 参数:
frame: 待检测图像(ndarray格式cv2.imdecode生成) frame: 待检测图像(ndarray格式cv2.imdecode生成)
返回: 返回:
Future对象通过result()方法获取检测结果 Future对象通过result()方法获取检测结果
""" """
# 确保模型已加载 # 确保模型已加载
if not _model_loaded: if not _model_loaded:
print("警告: 模型尚未加载将自动加载") print("警告: 模型尚未加载将自动加载")
load_model() load_model()
# 生成任务ID # 生成任务ID
@ -131,6 +131,6 @@ def detect(frame: np.ndarray) -> Future:
# 提交任务到线程池 # 提交任务到线程池
future = _executor.submit(_detect_in_thread, frame, task_id) future = _executor.submit(_detect_in_thread, frame, task_id)
print(f"任务[{task_id}] 已提交到线程池") print(f"任务[{task_id}]: 已提交到线程池")
return future return future

View File

@ -16,7 +16,7 @@ try:
pynvml.nvmlInit() pynvml.nvmlInit()
_nvml_available = True _nvml_available = True
except ImportError: except ImportError:
print("警告: pynvml库未安装无法检测GPU状态将默认使用0号GPU") print("警告: pynvml库未安装无法检测GPU状态将默认使用0号GPU")
_nvml_available = False _nvml_available = False
# 全局变量 # 全局变量
@ -58,7 +58,7 @@ def check_gpu_availability(gpu_id, threshold=0.7):
def select_best_gpu(preferred_gpus=[0, 1]): def select_best_gpu(preferred_gpus=[0, 1]):
"""选择最佳可用GPU严格按照首选列表顺序检查优先使用0号GPU""" """选择最佳可用GPU严格按照首选列表顺序检查优先使用0号GPU"""
# 首先检查首选GPU列表 # 首先检查首选GPU列表
for gpu_id in preferred_gpus: for gpu_id in preferred_gpus:
try: try:
@ -68,17 +68,17 @@ def select_best_gpu(preferred_gpus=[0, 1]):
# 检查GPU是否可用 # 检查GPU是否可用
if check_gpu_availability(gpu_id): if check_gpu_availability(gpu_id):
print(f"GPU {gpu_id} 可用将使用该GPU") print(f"GPU {gpu_id} 可用将使用该GPU")
return gpu_id return gpu_id
else: else:
if gpu_id == 0: if gpu_id == 0:
print(f"GPU 0 内存使用率过高(繁忙)尝试切换到其他GPU") print(f"GPU 0 内存使用率过高(繁忙)尝试切换到其他GPU")
except Exception as e: except Exception as e:
print(f"GPU {gpu_id} 不存在或无法访问: {e}") print(f"GPU {gpu_id} 不存在或无法访问: {e}")
continue continue
# 如果所有首选GPU都不可用返回-1表示使用CPU # 如果所有首选GPU都不可用返回-1表示使用CPU
print("所有指定的GPU都不可用将使用CPU进行计算") print("所有指定的GPU都不可用将使用CPU进行计算")
return -1 return -1
@ -122,12 +122,12 @@ def _release_engine():
def _monitor_thread(): def _monitor_thread():
"""监控线程检查并释放超时未使用的资源""" """监控线程检查并释放超时未使用的资源"""
global _ref_count, _last_used_time, _face_app global _ref_count, _last_used_time, _face_app
while True: while True:
time.sleep(5) # 每5秒检查一次 time.sleep(5) # 每5秒检查一次
with _lock: with _lock:
# 只有当引擎存在、没有引用且超时才释放 # 只有当引擎存在、没有引用且超时才释放
if _face_app and _ref_count == 0 and not _is_releasing: if _face_app and _ref_count == 0 and not _is_releasing:
elapsed = time.time() - _last_used_time elapsed = time.time() - _last_used_time
if elapsed > _release_timeout: if elapsed > _release_timeout:
@ -136,7 +136,7 @@ def _monitor_thread():
def load_model(prefer_gpu=True, preferred_gpus=[0, 1]): def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
"""加载人脸识别模型及已知人脸特征库默认优先使用0号GPU""" """加载人脸识别模型及已知人脸特征库默认优先使用0号GPU"""
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
# 确保监控线程只启动一次 # 确保监控线程只启动一次
@ -144,11 +144,11 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
threading.Thread(target=_monitor_thread, daemon=True, name="FaceMonitor").start() threading.Thread(target=_monitor_thread, daemon=True, name="FaceMonitor").start()
print("Face monitor thread started") print("Face monitor thread started")
# 如果正在释放中等待释放完成 # 如果正在释放中等待释放完成
while _is_releasing: while _is_releasing:
time.sleep(0.1) time.sleep(0.1)
# 如果已经初始化直接返回 # 如果已经初始化直接返回
if _face_app: if _face_app:
return True return True
@ -158,7 +158,7 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
print("正在初始化InsightFace人脸识别引擎...") print("正在初始化InsightFace人脸识别引擎...")
_face_app = FaceAnalysis(name='buffalo_l', root='~/.insightface') _face_app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
# 选择合适的GPU默认优先使用0号 # 选择合适的GPU默认优先使用0号
ctx_id = 0 ctx_id = 0
if prefer_gpu: if prefer_gpu:
ctx_id = select_best_gpu(preferred_gpus) ctx_id = select_best_gpu(preferred_gpus)
@ -166,9 +166,9 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
_used_gpu_id = ctx_id if _using_gpu else -1 _used_gpu_id = ctx_id if _using_gpu else -1
if _using_gpu: if _using_gpu:
print(f"成功初始化使用GPU {ctx_id} 进行计算") print(f"成功初始化使用GPU {ctx_id} 进行计算")
else: else:
print("成功初始化使用CPU进行计算") print("成功初始化使用CPU进行计算")
# 准备模型 # 准备模型
_face_app.prepare(ctx_id=ctx_id, det_size=(640, 640)) _face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
@ -188,10 +188,10 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
for person_name, eigenvalue_data in face_data.items(): for person_name, eigenvalue_data in face_data.items():
# 处理特征值数据 - 兼容数组和字符串两种格式 # 处理特征值数据 - 兼容数组和字符串两种格式
if isinstance(eigenvalue_data, np.ndarray): if isinstance(eigenvalue_data, np.ndarray):
# 如果已经是numpy数组直接使用 # 如果已经是numpy数组直接使用
eigenvalue = eigenvalue_data.astype(np.float32) eigenvalue = eigenvalue_data.astype(np.float32)
elif isinstance(eigenvalue_data, str): elif isinstance(eigenvalue_data, str):
# 清理字符串移除方括号、换行符和多余空格 # 清理字符串: 移除方括号、换行符和多余空格
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip() cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
# 按空格或逗号分割(处理可能的不同分隔符) # 按空格或逗号分割(处理可能的不同分隔符)
values = [v for v in cleaned.split() if v] values = [v for v in cleaned.split() if v]
@ -217,7 +217,7 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
def detect(frame, threshold=0.4): def detect(frame, threshold=0.4):
"""检测并识别人脸返回结果元组(是否匹配到已知人脸, 结果字符串)""" """检测并识别人脸返回结果元组(是否匹配到已知人脸, 结果字符串)"""
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
global _ref_count, _last_used_time global _ref_count, _last_used_time
@ -248,7 +248,7 @@ def detect(frame, threshold=0.4):
return (False, "人脸识别引擎不可用或未初始化") return (False, "人脸识别引擎不可用或未初始化")
try: try:
# 如果使用GPU确保输入帧在处理前是连续的数组 # 如果使用GPU确保输入帧在处理前是连续的数组
if _using_gpu and not frame.flags.contiguous: if _using_gpu and not frame.flags.contiguous:
frame = np.ascontiguousarray(frame) frame = np.ascontiguousarray(frame)
@ -285,7 +285,7 @@ def detect(frame, threshold=0.4):
# 判断匹配结果 # 判断匹配结果
is_match = max_sim >= threshold is_match = max_sim >= threshold
if is_match: if is_match:
has_matched = True # 只要有一个匹配成功就标记为True has_matched = True # 只要有一个匹配成功就标记为True
bbox = face.bbox bbox = face.bbox
result_parts.append( result_parts.append(
@ -298,12 +298,12 @@ def detect(frame, threshold=0.4):
else: else:
result_str = "; ".join(result_parts) result_str = "; ".join(result_parts)
# 减少引用计数确保线程安全 # 减少引用计数确保线程安全
with _lock: with _lock:
_ref_count = max(0, _ref_count - 1) _ref_count = max(0, _ref_count - 1)
# 持续使用时更新最后使用时间 # 持续使用时更新最后使用时间
if _ref_count > 0: if _ref_count > 0:
_last_used_time = time.time() _last_used_time = time.time()
# 第一个返回值为是否匹配到已知人脸 # 第一个返回值为: 是否匹配到已知人脸
return (has_matched, result_str) return (has_matched, result_str)

View File

@ -61,12 +61,12 @@ def _release_engine():
def _monitor_thread(): def _monitor_thread():
"""监控线程优化检查逻辑""" """监控线程优化检查逻辑"""
global _ref_count, _last_used_time, _ocr_engine global _ref_count, _last_used_time, _ocr_engine
while True: while True:
time.sleep(5) # 每5秒检查一次 time.sleep(5) # 每5秒检查一次
with _lock: with _lock:
# 只有当引擎存在、没有引用且超时才释放 # 只有当引擎存在、没有引用且超时才释放
if _ocr_engine and _ref_count == 0 and not _is_releasing: if _ocr_engine and _ref_count == 0 and not _is_releasing:
elapsed = time.time() - _last_used_time elapsed = time.time() - _last_used_time
if elapsed > _release_timeout: if elapsed > _release_timeout:
@ -100,7 +100,7 @@ def load_model():
def detect(frame): def detect(frame):
"""OCR检测优化引用计数管理""" """OCR检测优化引用计数管理"""
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time
# 验证前置条件 # 验证前置条件
@ -178,7 +178,7 @@ def detect(frame):
return (False, f"检测错误: {str(e)}") return (False, f"检测错误: {str(e)}")
finally: finally:
# 减少引用计数确保线程安全 # 减少引用计数确保线程安全
with _lock: with _lock:
_ref_count = max(0, _ref_count - 1) _ref_count = max(0, _ref_count - 1)
# 持续使用时更新最后使用时间 # 持续使用时更新最后使用时间

View File

@ -1,6 +1,5 @@
import os import os
import cv2
from ultralytics import YOLO from ultralytics import YOLO
# 全局变量 # 全局变量
@ -24,7 +23,7 @@ def load_model():
def detect(frame, conf_threshold=0.2): def detect(frame, conf_threshold=0.2):
"""YOLO目标检测返回(是否识别到, 结果字符串)""" """YOLO目标检测返回(是否识别到, 结果字符串)"""
global _yolo_model global _yolo_model
if not _yolo_model or frame is None: if not _yolo_model or frame is None:

View File

@ -14,4 +14,3 @@ config.read(config_path, encoding="utf-8")
SERVER_CONFIG = config["server"] SERVER_CONFIG = config["server"]
MYSQL_CONFIG = config["mysql"] MYSQL_CONFIG = config["mysql"]
JWT_CONFIG = config["jwt"] JWT_CONFIG = config["jwt"]
LIVE_CONFIG = config["live"]

View File

@ -47,7 +47,7 @@ if __name__ == "__main__":
YOLO_MODEL_PATH = r"/core/models\best.pt" YOLO_MODEL_PATH = r"/core/models\best.pt"
OCR_CONFIG_PATH = r"/core/config\config.yaml" OCR_CONFIG_PATH = r"/core/config\config.yaml"
# 初始化项目默认端口设为8000避免初始化失败时port未定义 # 初始化项目默认端口设为8000避免初始化失败时port未定义
port = int(SERVER_CONFIG.get("port", 8000)) port = int(SERVER_CONFIG.get("port", 8000))
# 启动 UVicorn 服务 # 启动 UVicorn 服务

View File

@ -9,8 +9,6 @@ from passlib.context import CryptContext
from ds.config import JWT_CONFIG from ds.config import JWT_CONFIG
from ds.db import db from ds.db import db
# 移除这里的 from service.user_service import UserResponse 导入
# ------------------------------ # ------------------------------
# 密码加密配置 # 密码加密配置
# ------------------------------ # ------------------------------
@ -23,7 +21,7 @@ SECRET_KEY = JWT_CONFIG["secret_key"]
ALGORITHM = JWT_CONFIG["algorithm"] ALGORITHM = JWT_CONFIG["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"]) ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 依赖(从请求头获取 Token、格式Bearer <token> # OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
@ -63,7 +61,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
# ------------------------------ # ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解 def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
"""从 Token 中解析用户信息、验证通过后返回当前用户""" """从 Token 中解析用户信息、验证通过后返回当前用户"""
# 延迟导入打破循环依赖 # 延迟导入打破循环依赖
from schema.user_schema import UserResponse # 在这里导入 from schema.user_schema import UserResponse # 在这里导入
# 认证失败异常 # 认证失败异常
@ -101,4 +99,4 @@ def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型
except Exception as e: except Exception as e:
raise credentials_exception from e raise credentials_exception from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

View File

@ -8,7 +8,7 @@ from schema.response_schema import APIResponse
async def global_exception_handler(request: Request, exc: Exception): async def global_exception_handler(request: Request, exc: Exception):
"""全局异常处理器所有未捕获的异常都会在这里统一处理""" """全局异常处理器: 所有未捕获的异常都会在这里统一处理"""
# 1. 请求参数验证错误Pydantic 校验失败) # 1. 请求参数验证错误Pydantic 校验失败)
if isinstance(exc, RequestValidationError): if isinstance(exc, RequestValidationError):
error_details = [] error_details = []
@ -18,7 +18,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
content=APIResponse( content=APIResponse(
code=400, code=400,
message=f"请求参数错误{'; '.join(error_details)}", message=f"请求参数错误: {'; '.join(error_details)}",
data=None data=None
).model_dump() ).model_dump()
) )
@ -52,7 +52,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse( content=APIResponse(
code=500, code=500,
message=f"数据库错误{str(exc)}", message=f"数据库错误: {str(exc)}",
data=None data=None
).model_dump() ).model_dump()
) )
@ -62,7 +62,7 @@ async def global_exception_handler(request: Request, exc: Exception):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse( content=APIResponse(
code=500, code=500,
message=f"服务器内部错误{str(exc)}", message=f"服务器内部错误: {str(exc)}",
data=None data=None
).model_dump() ).model_dump()
) )

View File

@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
# ------------------------------ # ------------------------------
# 请求模型(新增记录用,极简) # 请求模型
# ------------------------------ # ------------------------------
class DeviceActionCreate(BaseModel): class DeviceActionCreate(BaseModel):
"""设备操作记录创建模型0=离线1=上线)""" """设备操作记录创建模型0=离线1=上线)"""
client_ip: str = Field(..., description="客户端IP") client_ip: str = Field(..., description="客户端IP")
action: int = Field(..., ge=0, le=1, description="操作状态0=离线1=上线)") action: int = Field(..., ge=0, le=1, description="操作状态0=离线1=上线)")
# ------------------------------ # ------------------------------
@ -19,7 +19,7 @@ class DeviceActionResponse(BaseModel):
"""设备操作记录响应模型(与自增表对齐)""" """设备操作记录响应模型(与自增表对齐)"""
id: int = Field(..., description="自增主键ID") id: int = Field(..., description="自增主键ID")
client_ip: Optional[str] = Field(None, description="客户端IP") client_ip: Optional[str] = Field(None, description="客户端IP")
action: Optional[int] = Field(None, description="操作状态0=离线1=上线)") action: Optional[int] = Field(None, description="操作状态0=离线1=上线)")
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间")

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
# 请求模型(前端传参校验) # 请求模型(前端传参校验)
# ------------------------------ # ------------------------------
class FaceCreateRequest(BaseModel): class FaceCreateRequest(BaseModel):
"""创建人脸记录请求模型无需ID由数据库自增)""" """创建人脸记录请求模型无需ID由数据库自增)"""
name: str = Field(None, max_length=255, description="名称(可选最长255字符") name: str = Field(None, max_length=255, description="名称(可选最长255字符")
class FaceUpdateRequest(BaseModel): class FaceUpdateRequest(BaseModel):
@ -20,7 +20,7 @@ class FaceUpdateRequest(BaseModel):
# 响应模型(后端返回数据) # 响应模型(后端返回数据)
# ------------------------------ # ------------------------------
class FaceResponse(BaseModel): class FaceResponse(BaseModel):
"""人脸记录响应模型仍包含ID由数据库生成后返回)""" """人脸记录响应模型仍包含ID由数据库生成后返回)"""
id: int = Field(..., description="主键ID数据库自增") id: int = Field(..., description="主键ID数据库自增")
name: str = Field(None, description="名称") name: str = Field(None, description="名称")
eigenvalue: str | None = Field(None, description="特征(可为空)") eigenvalue: str | None = Field(None, description="特征(可为空)")

View File

@ -5,9 +5,9 @@ from pydantic import BaseModel, Field
class APIResponse(BaseModel): class APIResponse(BaseModel):
"""统一 API 响应模型(所有接口必返此格式)""" """统一 API 响应模型(所有接口必返此格式)"""
code: int = Field(..., description="状态码200=成功、4xx=客户端错误、5xx=服务端错误") code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
message: str = Field(..., description="响应信息成功/错误描述") message: str = Field(..., description="响应信息: 成功/错误描述")
data: Optional[Any] = Field(None, description="响应数据成功时返回、错误时为 None") data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
# Pydantic V2 配置(支持从 ORM 对象转换) # Pydantic V2 配置(支持从 ORM 对象转换)
model_config = {"from_attributes": True} model_config = {"from_attributes": True}

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
# ------------------------------ # ------------------------------
class SensitiveCreateRequest(BaseModel): class SensitiveCreateRequest(BaseModel):
"""创建敏感信息记录请求模型""" """创建敏感信息记录请求模型"""
# 移除了id字段由数据库自动生成 # 移除了id字段由数据库自动生成
name: str = Field(None, max_length=255, description="名称") name: str = Field(None, max_length=255, description="名称")

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, Query from fastapi import APIRouter, Query, Path
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
from ds.db import db from ds.db import db
@ -9,7 +9,6 @@ from schema.device_action_schema import (
) )
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
# 路由配置 # 路由配置
router = APIRouter( router = APIRouter(
prefix="/device/actions", prefix="/device/actions",
@ -18,11 +17,11 @@ router = APIRouter(
# ------------------------------ # ------------------------------
# 内部方法新增设备操作记录适配id自增 # 内部方法: 新增设备操作记录适配id自增
# ------------------------------ # ------------------------------
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse: def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
""" """
新增设备操作记录(内部方法非接口) 新增设备操作记录(内部方法非接口)
:param action_data: 含client_ip和action0/1 :param action_data: 含client_ip和action0/1
:return: 新增的完整记录 :return: 新增的完整记录
""" """
@ -32,7 +31,7 @@ def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 插入SQLid自增依赖数据库自动生成) # 插入SQLid自增依赖数据库自动生成)
insert_query = """ insert_query = """
INSERT INTO device_action INSERT INTO device_action
(client_ip, action, created_at, updated_at) (client_ip, action, created_at, updated_at)
@ -54,20 +53,20 @@ def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"新增记录失败{str(e)}") from e raise Exception(f"新增记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 接口分页查询操作记录列表(仅返回 total + device_actions # 接口: 分页查询操作记录列表(仅返回 total + device_actions
# ------------------------------ # ------------------------------
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录") @router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
async def get_device_action_list( async def get_device_action_list(
page: int = Query(1, ge=1, description="页码默认1"), page: int = Query(1, ge=1, description="页码默认1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100"), page_size: int = Query(10, ge=1, le=100, description="每页条数1-100"),
client_ip: str = Query(None, description="按客户端IP筛选"), client_ip: str = Query(None, description="按客户端IP筛选"),
action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线1=上线)") action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线1=上线)")
): ):
conn = None conn = None
cursor = None cursor = None
@ -75,7 +74,7 @@ async def get_device_action_list(
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 1. 构建筛选条件(参数化查询避免注入) # 1. 构建筛选条件(参数化查询避免注入)
where_clause = [] where_clause = []
params = [] params = []
if client_ip: if client_ip:
@ -92,13 +91,13 @@ async def get_device_action_list(
cursor.execute(count_sql, params) cursor.execute(count_sql, params)
total = cursor.fetchone()["total"] total = cursor.fetchone()["total"]
# 3. 分页查询记录(按创建时间倒序确保最新记录在前) # 3. 分页查询记录(按创建时间倒序确保最新记录在前)
offset = (page - 1) * page_size offset = (page - 1) * page_size
list_sql = "SELECT * FROM device_action" list_sql = "SELECT * FROM device_action"
if where_clause: if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause) list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s" list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) # 追加分页参数page/page_size仅用于查询不返回) params.extend([page_size, offset]) # 追加分页参数page/page_size仅用于查询不返回)
cursor.execute(list_sql, params) cursor.execute(list_sql, params)
action_list = cursor.fetchall() action_list = cursor.fetchall()
@ -114,6 +113,46 @@ async def get_device_action_list(
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询记录失败{str(e)}") from e raise Exception(f"查询记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
async def get_device_actions_by_ip(
client_ip: str = Path(..., description="客户端IP地址")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 查询总记录数
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
cursor.execute(count_sql, (client_ip,))
total = cursor.fetchone()["total"]
# 2. 查询该IP的所有记录按创建时间倒序
list_sql = """
SELECT * FROM device_action
WHERE client_ip = %s
ORDER BY created_at DESC
"""
cursor.execute(list_sql, (client_ip,))
action_list = cursor.fetchall()
# 3. 返回结果
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -37,7 +37,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
if not cursor.fetchone(): if not cursor.fetchone():
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 报警次数加1并更新时间戳 # 报警次数加1并更新时间戳
update_query = """ update_query = """
UPDATE devices UPDATE devices
SET alarm_count = alarm_count + 1, SET alarm_count = alarm_count + 1,
@ -51,7 +51,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"更新报警次数失败{str(e)}") from e raise Exception(f"更新报警次数失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -99,7 +99,7 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"更新设备在线状态失败{str(e)}") from e raise Exception(f"更新设备在线状态失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -124,7 +124,7 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
# 返回信息 # 返回信息
return APIResponse( return APIResponse(
code=200, code=200,
message=f"设备IP {device_data.ip} 已存在返回已有设备信息", message=f"设备IP {device_data.ip} 已存在返回已有设备信息",
data=DeviceResponse(** existing_device) data=DeviceResponse(** existing_device)
) )
@ -173,9 +173,9 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"创建设备失败{str(e)}") from e raise Exception(f"创建设备失败: {str(e)}") from e
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise Exception(f"设备详细信息JSON序列化失败{str(e)}") from e raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e
except Exception as e: except Exception as e:
if conn: if conn:
conn.rollback() conn.rollback()
@ -185,8 +185,8 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
async def get_device_list( async def get_device_list(
page: int = Query(1, ge=1, description="页码默认第1页"), page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"), page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
device_type: str = Query(None, description="按设备类型筛选"), device_type: str = Query(None, description="按设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选") online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
): ):
@ -233,6 +233,6 @@ async def get_device_list(
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"获取设备列表失败{str(e)}") from e raise Exception(f"获取设备列表失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

View File

@ -18,27 +18,27 @@ router = APIRouter(
# ------------------------------ # ------------------------------
# 1. 创建人脸记录(核心修正ID 数据库自增前端无需传) # 1. 创建人脸记录(核心修正: ID 数据库自增前端无需传)
# ------------------------------ # ------------------------------
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件ID自增") @router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件ID自增")
async def create_face( async def create_face(
# 前端仅需传name可选Form格式、file必传文件) # 前端仅需传: name可选Form格式、file必传文件)
name: str = Form(None, max_length=255, description="名称(可选)"), name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传暂不处理内容)") file: UploadFile = File(..., description="人脸文件(必传暂不处理内容)")
): ):
""" """
创建人脸记录 创建人脸记录:
- 需登录认证 - 需登录认证
- 前端传参multipart/form-data 表单name 可选file 必传) - 前端传参: multipart/form-data 表单name 可选file 必传)
- ID 由数据库自动生成无需前端传入 - ID 由数据库自动生成无需前端传入
- 暂不处理文件内容eigenvalue 设为 None - 暂不处理文件内容eigenvalue 设为 None
""" """
# 调用你的方法 # 调用你的方法
conn = None conn = None
cursor = None cursor = None
try: try:
# 1. 用模型校验 name仅校验长度无需ID # 1. 用模型校验 name仅校验长度无需ID
face_create = FaceCreateRequest(name=name) face_create = FaceCreateRequest(name=name)
conn = db.get_connection() conn = db.get_connection()
@ -57,9 +57,9 @@ async def create_face(
) )
# 打印数组长度 # 打印数组长度
print(f"文件大小{len(file_content)} 字节") print(f"文件大小: {len(file_content)} 字节")
# 2. 插入数据库无需传 ID自增只传 name 和 eigenvalueNone # 2. 插入数据库: 无需传 ID自增只传 name 和 eigenvalueNone
insert_query = """ insert_query = """
INSERT INTO face (name, eigenvalue) INSERT INTO face (name, eigenvalue)
VALUES (%s, %s) VALUES (%s, %s)
@ -67,7 +67,7 @@ async def create_face(
cursor.execute(insert_query, (face_create.name, str(eigenvalue))) cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
conn.commit() conn.commit()
# 3. 获取数据库自动生成的 ID关键用 LAST_INSERT_ID() 查刚插入的记录) # 3. 获取数据库自动生成的 ID关键: 用 LAST_INSERT_ID() 查刚插入的记录)
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()" select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
cursor.execute(select_new_query) cursor.execute(select_new_query)
created_face = cursor.fetchone() created_face = cursor.fetchone()
@ -75,12 +75,12 @@ async def create_face(
if not created_face: if not created_face:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="创建人脸记录成功但无法获取新创建的记录" detail="创建人脸记录成功但无法获取新创建的记录"
) )
return APIResponse( return APIResponse(
code=201, code=201,
message=f"人脸记录创建成功ID{created_face['id']}文件名{file.filename}", message=f"人脸记录创建成功ID: {created_face['id']}文件名: {file.filename}",
data=FaceResponse(** created_face) data=FaceResponse(** created_face)
) )
except MySQLError as e: except MySQLError as e:
@ -89,13 +89,13 @@ async def create_face(
# 改为使用HTTPException # 改为使用HTTPException
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"创建人脸记录失败{str(e)}" detail=f"创建人脸记录失败: {str(e)}"
) from e ) from e
except Exception as e: except Exception as e:
# 捕获其他可能的异常 # 捕获其他可能的异常
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"服务器错误{str(e)}" detail=f"服务器错误: {str(e)}"
) from e ) from e
finally: finally:
await file.close() # 关闭文件流 await file.close() # 关闭文件流
@ -113,11 +113,11 @@ async def create_face(
eigenvalue = str(eigenvalue) eigenvalue = str(eigenvalue)
# ------------------------------ # ------------------------------
# 2. 获取单个人脸记录(不变用自增ID查询 # 2. 获取单个人脸记录(不变用自增ID查询
# ------------------------------ # ------------------------------
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录") @router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
async def get_face( async def get_face(
face_id: int, # 这里的 ID 是数据库自增的前端从创建响应中获取 face_id: int, # 这里的 ID 是数据库自增的前端从创建响应中获取
current_user: UserResponse = Depends(get_current_user) current_user: UserResponse = Depends(get_current_user)
): ):
conn = None conn = None
@ -145,7 +145,7 @@ async def get_face(
# 改为使用HTTPException # 改为使用HTTPException
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"查询人脸记录失败{str(e)}" detail=f"查询人脸记录失败: {str(e)}"
) from e ) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -176,14 +176,14 @@ async def get_all_faces(
except MySQLError as e: except MySQLError as e:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"查询所有人脸记录失败{str(e)}" detail=f"查询所有人脸记录失败: {str(e)}"
) from e ) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 4. 更新人脸记录(不变用自增ID更新 # 4. 更新人脸记录(不变用自增ID更新
# ------------------------------ # ------------------------------
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录") @router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
async def update_face( async def update_face(
@ -240,14 +240,14 @@ async def update_face(
conn.rollback() conn.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"更新人脸记录失败{str(e)}" detail=f"更新人脸记录失败: {str(e)}"
) from e ) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 5. 删除人脸记录(不变用自增ID删除 # 5. 删除人脸记录(不变用自增ID删除
# ------------------------------ # ------------------------------
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录") @router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
async def delete_face( async def delete_face(
@ -283,7 +283,7 @@ async def delete_face(
conn.rollback() conn.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"删除人脸记录失败{str(e)}" detail=f"删除人脸记录失败: {str(e)}"
) from e ) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -291,10 +291,10 @@ async def delete_face(
def get_all_face_name_with_eigenvalue() -> dict: def get_all_face_name_with_eigenvalue() -> dict:
""" """
获取所有人脸的名称及其对应的特征值组成字典返回 获取所有人脸的名称及其对应的特征值组成字典返回
key: 人脸名称name key: 人脸名称name
value: 人脸特征值eigenvalue若名称重复则返回平均特征值 value: 人脸特征值eigenvalue若名称重复则返回平均特征值
过滤掉name为None的记录避免字典key为None的情况 : 过滤掉name为None的记录避免字典key为None的情况
""" """
conn = None conn = None
cursor = None cursor = None
@ -303,27 +303,27 @@ def get_all_face_name_with_eigenvalue() -> dict:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 2. 执行SQL查询只获取name非空的记录减少数据传输 # 2. 执行SQL查询: 只获取name非空的记录减少数据传输
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query) cursor.execute(query)
faces = cursor.fetchall() # 返回结果列表套字典如 [{"name":"张三","eigenvalue":...}, ...] faces = cursor.fetchall() # 返回结果: 列表套字典如 [{"name":"张三","eigenvalue":...}, ...]
# 3. 收集同一名称对应的所有特征值(处理名称重复场景) # 3. 收集同一名称对应的所有特征值(处理名称重复场景)
name_to_eigenvalues = {} name_to_eigenvalues = {}
for face in faces: for face in faces:
name = face["name"] name = face["name"]
eigenvalue = face["eigenvalue"] eigenvalue = face["eigenvalue"]
# 若名称已存在追加特征值;否则新建列表存储 # 若名称已存在追加特征值;否则新建列表存储
if name in name_to_eigenvalues: if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue) name_to_eigenvalues[name].append(eigenvalue)
else: else:
name_to_eigenvalues[name] = [eigenvalue] name_to_eigenvalues[name] = [eigenvalue]
# 4. 构建最终字典重复名称取平均唯一名称直接取特征值 # 4. 构建最终字典: 重复名称取平均唯一名称直接取特征值
face_dict = {} face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items(): for name, eigenvalues in name_to_eigenvalues.items():
# 处理特征值多个则求平均单个则直接使用 # 处理特征值: 多个则求平均单个则直接使用
if len(eigenvalues) > 1: if len(eigenvalues) > 1:
# 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入 # 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入
face_dict[name] = get_average_feature(eigenvalues) face_dict[name] = get_average_feature(eigenvalues)
@ -334,8 +334,8 @@ def get_all_face_name_with_eigenvalue() -> dict:
return face_dict return face_dict
except MySQLError as e: except MySQLError as e:
# 捕获数据库异常添加上下文信息后重新抛出(便于定位问题) # 捕获数据库异常添加上下文信息后重新抛出(便于定位问题)
raise Exception(f"获取人脸名称与特征值失败{str(e)}") from e raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
finally: finally:
# 5. 无论是否异常均释放数据库连接和游标(避免资源泄漏) # 5. 无论是否异常均释放数据库连接和游标(避免资源泄漏)
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

View File

@ -21,7 +21,7 @@ router = APIRouter(
async def create_sensitive( async def create_sensitive(
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖 sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
""" """
创建敏感信息记录 创建敏感信息记录:
- 需登录认证 - 需登录认证
- 插入新的敏感信息记录到数据库ID由数据库自动生成 - 插入新的敏感信息记录到数据库ID由数据库自动生成
- 返回创建成功信息 - 返回创建成功信息
@ -32,7 +32,7 @@ async def create_sensitive(
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 插入新敏感信息记录到数据库不包含ID由数据库自动生成) # 插入新敏感信息记录到数据库不包含ID由数据库自动生成)
insert_query = """ insert_query = """
INSERT INTO sensitives (name) INSERT INTO sensitives (name)
VALUES (%s) VALUES (%s)
@ -56,7 +56,7 @@ async def create_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"创建敏感信息记录失败{str(e)}") from e raise Exception(f"创建敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -71,7 +71,7 @@ async def get_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
获取单个敏感信息记录 获取单个敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID查询敏感信息记录 - 根据ID查询敏感信息记录
- 返回查询到的敏感信息 - 返回查询到的敏感信息
@ -98,7 +98,7 @@ async def get_sensitive(
data=SensitiveResponse(**sensitive) data=SensitiveResponse(**sensitive)
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询敏感信息记录失败{str(e)}") from e raise Exception(f"查询敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -109,7 +109,7 @@ async def get_sensitive(
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录") @router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
async def get_all_sensitives(): async def get_all_sensitives():
""" """
获取所有敏感信息记录 获取所有敏感信息记录:
- 需登录认证 - 需登录认证
- 查询所有敏感信息记录(不需要分页) - 查询所有敏感信息记录(不需要分页)
- 返回所有敏感信息列表 - 返回所有敏感信息列表
@ -130,7 +130,7 @@ async def get_all_sensitives():
data=[SensitiveResponse(**sensitive) for sensitive in sensitives] data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询所有敏感信息记录失败{str(e)}") from e raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -145,7 +145,7 @@ async def update_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
更新敏感信息记录 更新敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID更新敏感信息记录 - 根据ID更新敏感信息记录
- 返回更新后的敏感信息 - 返回更新后的敏感信息
@ -203,7 +203,7 @@ async def update_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"更新敏感信息记录失败{str(e)}") from e raise Exception(f"更新敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -217,7 +217,7 @@ async def delete_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
删除敏感信息记录 删除敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID删除敏感信息记录 - 根据ID删除敏感信息记录
- 返回删除成功信息 - 返回删除成功信息
@ -251,14 +251,14 @@ async def delete_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"删除敏感信息记录失败{str(e)}") from e raise Exception(f"删除敏感信息记录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
def get_all_sensitive_words() -> list[str]: def get_all_sensitive_words() -> list[str]:
""" """
获取所有敏感词返回字符串数组 获取所有敏感词返回字符串数组
返回: 返回:
list[str]: 包含所有敏感词的数组 list[str]: 包含所有敏感词的数组
@ -273,7 +273,7 @@ def get_all_sensitive_words() -> list[str]:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 执行查询只获取敏感词字段 # 执行查询只获取敏感词字段
query = "SELECT name FROM sensitives ORDER BY id" query = "SELECT name FROM sensitives ORDER BY id"
cursor.execute(query) cursor.execute(query)
sensitive_records = cursor.fetchall() sensitive_records = cursor.fetchall()
@ -283,7 +283,7 @@ def get_all_sensitive_words() -> list[str]:
except MySQLError as e: except MySQLError as e:
# 数据库错误处理 # 数据库错误处理
raise MySQLError(f"查询敏感词失败{str(e)}") from e raise MySQLError(f"查询敏感词失败: {str(e)}") from e
finally: finally:
# 确保资源正确释放 # 确保资源正确释放
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

View File

@ -27,7 +27,7 @@ router = APIRouter(
@router.post("/register", response_model=APIResponse, summary="用户注册") @router.post("/register", response_model=APIResponse, summary="用户注册")
async def user_register(request: UserRegisterRequest): async def user_register(request: UserRegisterRequest):
""" """
用户注册 用户注册:
- 校验用户名是否已存在 - 校验用户名是否已存在
- 加密密码后插入数据库 - 加密密码后插入数据库
- 返回注册成功信息 - 返回注册成功信息
@ -67,7 +67,7 @@ async def user_register(request: UserRegisterRequest):
) )
except MySQLError as e: except MySQLError as e:
conn.rollback() # 数据库错误时回滚事务 conn.rollback() # 数据库错误时回滚事务
raise Exception(f"注册失败{str(e)}") from e raise Exception(f"注册失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -78,7 +78,7 @@ async def user_register(request: UserRegisterRequest):
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token") @router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token")
async def user_login(request: UserLoginRequest): async def user_login(request: UserLoginRequest):
""" """
用户登录 用户登录:
- 校验用户名是否存在 - 校验用户名是否存在
- 校验密码是否正确 - 校验密码是否正确
- 生成 JWT Token 并返回 - 生成 JWT Token 并返回
@ -89,7 +89,7 @@ async def user_login(request: UserLoginRequest):
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 修复SQL查询添加 created_at 和 updated_at 字段 # 修复: SQL查询添加 created_at 和 updated_at 字段
query = """ query = """
SELECT id, username, password, created_at, updated_at SELECT id, username, password, created_at, updated_at
FROM users FROM users
@ -129,7 +129,7 @@ async def user_login(request: UserLoginRequest):
} }
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"登录失败{str(e)}") from e raise Exception(f"登录失败: {str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -142,8 +142,8 @@ async def get_current_user_info(
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件 current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
): ):
""" """
获取当前登录用户信息 获取当前登录用户信息:
- 需在请求头携带 Token格式Bearer <token> - 需在请求头携带 Token格式: Bearer <token>
- 认证通过后返回用户信息 - 认证通过后返回用户信息
""" """
return APIResponse( return APIResponse(

View File

@ -27,7 +27,7 @@ def init_insightface():
def add_binary_data(binary_data): def add_binary_data(binary_data):
""" """
接收单张图片的二进制数据提取特征并保存 接收单张图片的二进制数据提取特征并保存
参数: 参数:
binary_data: 图片的二进制数据bytes类型 binary_data: 图片的二进制数据bytes类型
@ -39,11 +39,11 @@ def add_binary_data(binary_data):
global _insightface_app, _feature_list global _insightface_app, _feature_list
if not _insightface_app: if not _insightface_app:
print("引擎未初始化无法处理") print("引擎未初始化无法处理")
return False, None return False, None
try: try:
# 直接处理二进制数据转换为图像格式 # 直接处理二进制数据: 转换为图像格式
img = Image.open(BytesIO(binary_data)) img = Image.open(BytesIO(binary_data))
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@ -70,15 +70,15 @@ def get_average_feature(features=None):
计算多个特征向量的平均值 计算多个特征向量的平均值
参数: 参数:
features: 可选特征值列表。如果未提供则使用全局存储的_feature_list features: 可选特征值列表。如果未提供则使用全局存储的_feature_list
每个元素可以是字符串格式或numpy数组 每个元素可以是字符串格式或numpy数组
返回: 返回:
单一平均特征向量的numpy数组若无可计算数据则返回None 单一平均特征向量的numpy数组若无可计算数据则返回None
""" """
global _feature_list global _feature_list
# 如果未提供features参数则使用全局特征列表 # 如果未提供features参数则使用全局特征列表
if features is None: if features is None:
features = _feature_list features = _feature_list
@ -105,7 +105,7 @@ def get_average_feature(features=None):
processed_features.append(embedding_np) processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值") print(f"已添加第 {i + 1} 个特征值用于计算平均值")
else: else:
print(f"跳过第 {i + 1} 个特征值不是一维数组") print(f"跳过第 {i + 1} 个特征值不是一维数组")
except Exception as e: except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}") print(f"处理第 {i + 1} 个特征值时出错: {e}")
@ -118,12 +118,12 @@ def get_average_feature(features=None):
# 检查所有特征向量维度是否相同 # 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features} dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1: if len(dims) > 1:
print(f"特征值维度不一致无法计算平均值。检测到的维度: {dims}") print(f"特征值维度不一致无法计算平均值。检测到的维度: {dims}")
return None return None
# 计算平均值 # 计算平均值
avg_feature = np.mean(processed_features, axis=0) avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量维度: {avg_feature.shape[0]}") print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量维度: {avg_feature.shape[0]}")
return avg_feature return avg_feature

View File

@ -21,7 +21,7 @@ WS_ENDPOINT = "/ws" # WebSocket端点路径
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制 FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
# 工具函数获取格式化时间字符串(统一时间戳格式) # 工具函数: 获取格式化时间字符串(统一时间戳格式)
def get_current_time_str() -> str: def get_current_time_str() -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -65,32 +65,32 @@ class ClientConnection:
"client_ip": self.client_ip "client_ip": self.client_ip
} }
await self.websocket.send_json(frame_permit_msg) await self.websocket.send_json(frame_permit_msg)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}已发送帧发送许可信号(取帧后立即通知)") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号(取帧后立即通知)")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}帧许可信号发送失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}")
async def consume_frames(self) -> None: async def consume_frames(self) -> None:
"""消费队列中的帧并处理(核心调整取帧后立即发许可再处理帧)""" """消费队列中的帧并处理(核心调整: 取帧后立即发许可再处理帧)"""
try: try:
while True: while True:
# 1. 从队列取出帧(阻塞直到有帧可用) # 1. 从队列取出帧(阻塞直到有帧可用)
frame_data = await self.frame_queue.get() frame_data = await self.frame_queue.get()
# -------------------------- 核心修改取出帧后立即发送下一帧许可 -------------------------- # -------------------------- 核心修改: 取出帧后立即发送下一帧许可 --------------------------
await self.send_frame_permit() # 取帧即通知客户端发下一帧无需等处理完成 await self.send_frame_permit() # 取帧即通知客户端发下一帧无需等处理完成
# ----------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------
try: try:
# 2. 处理取出的帧(即使处理慢客户端也已收到许可可提前准备下一帧) # 2. 处理取出的帧(即使处理慢客户端也已收到许可可提前准备下一帧)
await self.process_frame(frame_data) await self.process_frame(frame_data)
finally: finally:
# 3. 标记帧任务完成(无论处理成功/失败都需清理队列) # 3. 标记帧任务完成(无论处理成功/失败都需清理队列)
self.frame_queue.task_done() self.frame_queue.task_done()
except asyncio.CancelledError: except asyncio.CancelledError:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}帧消费任务已取消") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}帧消费逻辑错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
async def process_frame(self, frame_data: bytes) -> None: async def process_frame(self, frame_data: bytes) -> None:
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)""" """处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)"""
@ -98,31 +98,31 @@ class ClientConnection:
nparr = np.frombuffer(frame_data, np.uint8) nparr = np.frombuffer(frame_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None: if img is None:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}无法解析图像数据") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像数据")
return return
# 确保图像保存目录存在 # 确保图像保存目录存在
os.makedirs('images', exist_ok=True) os.makedirs('images', exist_ok=True)
# 保存图像按IP+时间戳命名避免冲突) # 保存图像按IP+时间戳命名避免冲突)
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg" filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
try: try:
cv2.imwrite(filename, img) cv2.imwrite(filename, img)
print(f"[{get_current_time_str()}] 图像已保存至{filename}") print(f"[{get_current_time_str()}] 图像已保存至: {filename}")
has_violation, data, type = detect(img) has_violation, data, type = detect(img)
print(has_violation) print(has_violation)
print(type) print(type)
print(data) print(data)
if has_violation: if has_violation:
print( print(
f"[{get_current_time_str()}] 客户端{self.client_ip}检测到违规 - 类型: {type}, 详情: {data}") f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - 类型: {type}, 详情: {data}")
# 调用违规次数加一方法 # 调用违规次数加一方法
try: try:
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip) await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
print(f"[{get_current_time_str()}] 客户端{self.client_ip}违规次数已+1") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}违规次数更新失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
# 发送「危险通知」 # 发送「危险通知」
danger_msg = { danger_msg = {
@ -132,9 +132,9 @@ class ClientConnection:
} }
await self.websocket.send_json(danger_msg) await self.websocket.send_json(danger_msg)
else: else:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}未检测到违规") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{self.client_ip}图像处理错误 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
# 全局状态管理 # 全局状态管理
@ -149,7 +149,7 @@ async def heartbeat_checker():
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()] timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
if timeout_ips: if timeout_ips:
print(f"[{current_time}] 心跳检查{len(timeout_ips)}个客户端超时IP{timeout_ips}") print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时IP: {timeout_ips}")
for ip in timeout_ips: for ip in timeout_ips:
try: try:
conn = connected_clients[ip] conn = connected_clients[ip]
@ -162,13 +162,13 @@ async def heartbeat_checker():
await asyncio.to_thread(update_online_status_by_ip, ip, 0) await asyncio.to_thread(update_online_status_by_ip, ip, 0)
action_data = DeviceActionCreate(client_ip=ip, action=0) action_data = DeviceActionCreate(client_ip=ip, action=0)
await asyncio.to_thread(add_device_action, action_data) await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{ip}已标记为离线并记录操作") print(f"[{current_time}] 客户端{ip}: 已标记为离线并记录操作")
except Exception as e: except Exception as e:
print(f"[{current_time}] 客户端{ip}离线状态更新失败 - {str(e)}") print(f"[{current_time}] 客户端{ip}: 离线状态更新失败 - {str(e)}")
finally: finally:
connected_clients.pop(ip, None) connected_clients.pop(ip, None)
else: else:
print(f"[{current_time}] 心跳检查{len(connected_clients)}个客户端在线") print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
@ -178,7 +178,7 @@ async def heartbeat_checker():
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global heartbeat_task global heartbeat_task
heartbeat_task = asyncio.create_task(heartbeat_checker()) heartbeat_task = asyncio.create_task(heartbeat_checker())
print(f"[{get_current_time_str()}] 全局心跳检查任务启动任务ID{id(heartbeat_task)}") print(f"[{get_current_time_str()}] 全局心跳检查任务启动任务ID: {id(heartbeat_task)}")
yield yield
if heartbeat_task and not heartbeat_task.done(): if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel() heartbeat_task.cancel()
@ -198,11 +198,11 @@ async def send_heartbeat_ack(conn: ClientConnection):
"client_ip": conn.client_ip "client_ip": conn.client_ip
} }
await conn.websocket.send_json(heartbeat_ack_msg) await conn.websocket.send_json(heartbeat_ack_msg)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}已发送心跳确认") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送心跳确认")
return True return True
except Exception as e: except Exception as e:
connected_clients.pop(conn.client_ip, None) connected_clients.pop(conn.client_ip, None)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}心跳确认发送失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认发送失败 - {str(e)}")
return False return False
@ -213,17 +213,17 @@ async def handle_text_msg(conn: ClientConnection, text: str):
conn.update_heartbeat() conn.update_heartbeat()
await send_heartbeat_ack(conn) await send_heartbeat_ack(conn)
else: else:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}未知文本消息类型({msg.get('type')}") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本消息类型({msg.get('type')}")
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}无效JSON文本消息") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON文本消息")
async def handle_binary_msg(conn: ClientConnection, data: bytes): async def handle_binary_msg(conn: ClientConnection, data: bytes):
try: try:
conn.frame_queue.put_nowait(data) conn.frame_queue.put_nowait(data)
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}图像数据({len(data)}字节)已加入队列") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像数据({len(data)}字节)已加入队列")
except asyncio.QueueFull: except asyncio.QueueFull:
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}帧队列已满丢弃当前图像数据") print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满丢弃当前图像数据")
# WebSocket路由配置 # WebSocket路由配置
@ -237,7 +237,7 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
client_ip = websocket.client.host if websocket.client else "unknown_ip" client_ip = websocket.client.host if websocket.client else "unknown_ip"
current_time = get_current_time_str() current_time = get_current_time_str()
print(f"[{current_time}] 客户端{client_ip}WebSocket连接已建立") print(f"[{current_time}] 客户端{client_ip}: WebSocket连接已建立")
is_online_updated = False is_online_updated = False
@ -249,13 +249,13 @@ async def websocket_endpoint(websocket: WebSocket):
old_conn.consumer_task.cancel() old_conn.consumer_task.cancel()
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立") await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
connected_clients.pop(client_ip) connected_clients.pop(client_ip)
print(f"[{current_time}] 客户端{client_ip}已关闭旧连接") print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
# 注册新连接 # 注册新连接
new_conn = ClientConnection(websocket, client_ip) new_conn = ClientConnection(websocket, client_ip)
connected_clients[client_ip] = new_conn connected_clients[client_ip] = new_conn
new_conn.start_consumer() new_conn.start_consumer()
# 初始许可连接建立后立即发一次让客户端知道可发第一帧(后续靠取帧后自动发) # 初始许可: 连接建立后立即发一次让客户端知道可发第一帧(后续靠取帧后自动发)
await new_conn.send_frame_permit() await new_conn.send_frame_permit()
# 标记上线并记录 # 标记上线并记录
@ -263,12 +263,12 @@ async def websocket_endpoint(websocket: WebSocket):
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1) await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
action_data = DeviceActionCreate(client_ip=client_ip, action=1) action_data = DeviceActionCreate(client_ip=client_ip, action=1)
await asyncio.to_thread(add_device_action, action_data) await asyncio.to_thread(add_device_action, action_data)
print(f"[{current_time}] 客户端{client_ip}已标记为在线并记录操作") print(f"[{current_time}] 客户端{client_ip}: 已标记为在线并记录操作")
is_online_updated = True is_online_updated = True
except Exception as e: except Exception as e:
print(f"[{current_time}] 客户端{client_ip}上线状态更新失败 - {str(e)}") print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
print(f"[{current_time}] 客户端{client_ip}新连接注册成功在线数{len(connected_clients)}") print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功在线数: {len(connected_clients)}")
# 消息循环 # 消息循环
while True: while True:
@ -279,9 +279,9 @@ async def websocket_endpoint(websocket: WebSocket):
await handle_binary_msg(new_conn, data["bytes"]) await handle_binary_msg(new_conn, data["bytes"])
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}主动断开连接(代码{e.code}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code}")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}连接异常 - {str(e)[:50]}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
finally: finally:
# 清理资源并标记离线 # 清理资源并标记离线
if client_ip in connected_clients: if client_ip in connected_clients:
@ -295,9 +295,9 @@ async def websocket_endpoint(websocket: WebSocket):
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0) await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
action_data = DeviceActionCreate(client_ip=client_ip, action=0) action_data = DeviceActionCreate(client_ip=client_ip, action=0)
await asyncio.to_thread(add_device_action, action_data) await asyncio.to_thread(add_device_action, action_data)
print(f"[{get_current_time_str()}] 客户端{client_ip}断开后已标记为离线") print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后已标记为离线")
except Exception as e: except Exception as e:
print(f"[{get_current_time_str()}] 客户端{client_ip}断开后离线更新失败 - {str(e)}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
connected_clients.pop(client_ip, None) connected_clients.pop(client_ip, None)
print(f"[{get_current_time_str()}] 客户端{client_ip}资源已清理在线数{len(connected_clients)}") print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理在线数: {len(connected_clients)}")