优化代码风格
This commit is contained in:
@ -12,8 +12,4 @@ charset = utf8mb4
|
|||||||
[jwt]
|
[jwt]
|
||||||
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
|
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
|
||||||
algorithm = HS256
|
algorithm = HS256
|
||||||
access_token_expire_minutes = 30
|
access_token_expire_minutes = 30
|
||||||
|
|
||||||
[live]
|
|
||||||
rtmp_url = rtmp://192.168.110.25:1935/live/
|
|
||||||
webrtc_url = http://192.168.110.25:1985/rtc/v1/whep/?app=live&stream=
|
|
18
core/all.py
18
core/all.py
@ -22,7 +22,7 @@ _task_counter_lock = threading.Lock() # 任务计数锁
|
|||||||
|
|
||||||
# -------------------------- 工具函数 --------------------------
|
# -------------------------- 工具函数 --------------------------
|
||||||
def _get_next_task_id():
|
def _get_next_task_id():
|
||||||
"""获取唯一任务ID,用于日志追踪"""
|
"""获取唯一任务ID、用于日志追踪"""
|
||||||
global _task_counter
|
global _task_counter
|
||||||
with _task_counter_lock:
|
with _task_counter_lock:
|
||||||
_task_counter += 1
|
_task_counter += 1
|
||||||
@ -65,11 +65,11 @@ def _init_thread_pool():
|
|||||||
max_workers=MAX_WORKERS,
|
max_workers=MAX_WORKERS,
|
||||||
thread_name_prefix="DetectionThread"
|
thread_name_prefix="DetectionThread"
|
||||||
)
|
)
|
||||||
print(f"=== 线程池初始化完成,最大线程数: {MAX_WORKERS} ===")
|
print(f"=== 线程池初始化完成、最大线程数: {MAX_WORKERS} ===")
|
||||||
|
|
||||||
|
|
||||||
def shutdown():
|
def shutdown():
|
||||||
"""关闭线程池,释放资源"""
|
"""关闭线程池、释放资源"""
|
||||||
global _executor
|
global _executor
|
||||||
with _executor_lock:
|
with _executor_lock:
|
||||||
if _executor is not None:
|
if _executor is not None:
|
||||||
@ -82,7 +82,7 @@ def shutdown():
|
|||||||
def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
|
def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
|
||||||
"""在子线程中执行检测逻辑"""
|
"""在子线程中执行检测逻辑"""
|
||||||
thread_name = threading.current_thread().name
|
thread_name = threading.current_thread().name
|
||||||
print(f"任务[{task_id}] 开始执行,线程: {thread_name}")
|
print(f"任务[{task_id}] 开始执行、线程: {thread_name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 按照优先级执行检测
|
# 按照优先级执行检测
|
||||||
@ -98,7 +98,7 @@ def _detect_in_thread(frame: np.ndarray, task_id: int) -> tuple:
|
|||||||
|
|
||||||
print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}")
|
print(f"任务[{task_id}] {detector}检测结果: {'成功' if flag else '失败'}")
|
||||||
if flag:
|
if flag:
|
||||||
print(f"任务[{task_id}] 完成检测,使用检测器: {detector}")
|
print(f"任务[{task_id}] 完成检测、使用检测器: {detector}")
|
||||||
return (True, result, detector, task_id)
|
return (True, result, detector, task_id)
|
||||||
|
|
||||||
# 所有检测器均未检测到结果
|
# 所有检测器均未检测到结果
|
||||||
@ -116,14 +116,14 @@ def detect(frame: np.ndarray) -> Future:
|
|||||||
提交检测任务到线程池
|
提交检测任务到线程池
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
frame: 待检测图像(ndarray格式,cv2.imdecode生成)
|
frame: 待检测图像(ndarray格式、cv2.imdecode生成)
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Future对象,通过result()方法获取检测结果
|
Future对象、通过result()方法获取检测结果
|
||||||
"""
|
"""
|
||||||
# 确保模型已加载
|
# 确保模型已加载
|
||||||
if not _model_loaded:
|
if not _model_loaded:
|
||||||
print("警告: 模型尚未加载,将自动加载")
|
print("警告: 模型尚未加载、将自动加载")
|
||||||
load_model()
|
load_model()
|
||||||
|
|
||||||
# 生成任务ID
|
# 生成任务ID
|
||||||
@ -131,6 +131,6 @@ def detect(frame: np.ndarray) -> Future:
|
|||||||
|
|
||||||
# 提交任务到线程池
|
# 提交任务到线程池
|
||||||
future = _executor.submit(_detect_in_thread, frame, task_id)
|
future = _executor.submit(_detect_in_thread, frame, task_id)
|
||||||
print(f"任务[{task_id}]: 已提交到线程池")
|
print(f"任务[{task_id}]: 已提交到线程池")
|
||||||
return future
|
return future
|
||||||
|
|
||||||
|
42
core/face.py
42
core/face.py
@ -16,7 +16,7 @@ try:
|
|||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
_nvml_available = True
|
_nvml_available = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("警告: pynvml库未安装,无法检测GPU状态,将默认使用0号GPU")
|
print("警告: pynvml库未安装、无法检测GPU状态、将默认使用0号GPU")
|
||||||
_nvml_available = False
|
_nvml_available = False
|
||||||
|
|
||||||
# 全局变量
|
# 全局变量
|
||||||
@ -58,7 +58,7 @@ def check_gpu_availability(gpu_id, threshold=0.7):
|
|||||||
|
|
||||||
|
|
||||||
def select_best_gpu(preferred_gpus=[0, 1]):
|
def select_best_gpu(preferred_gpus=[0, 1]):
|
||||||
"""选择最佳可用GPU,严格按照首选列表顺序检查,优先使用0号GPU"""
|
"""选择最佳可用GPU、严格按照首选列表顺序检查、优先使用0号GPU"""
|
||||||
# 首先检查首选GPU列表
|
# 首先检查首选GPU列表
|
||||||
for gpu_id in preferred_gpus:
|
for gpu_id in preferred_gpus:
|
||||||
try:
|
try:
|
||||||
@ -68,17 +68,17 @@ def select_best_gpu(preferred_gpus=[0, 1]):
|
|||||||
|
|
||||||
# 检查GPU是否可用
|
# 检查GPU是否可用
|
||||||
if check_gpu_availability(gpu_id):
|
if check_gpu_availability(gpu_id):
|
||||||
print(f"GPU {gpu_id} 可用,将使用该GPU")
|
print(f"GPU {gpu_id} 可用、将使用该GPU")
|
||||||
return gpu_id
|
return gpu_id
|
||||||
else:
|
else:
|
||||||
if gpu_id == 0:
|
if gpu_id == 0:
|
||||||
print(f"GPU 0 内存使用率过高(繁忙),尝试切换到其他GPU")
|
print(f"GPU 0 内存使用率过高(繁忙)、尝试切换到其他GPU")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"GPU {gpu_id} 不存在或无法访问: {e}")
|
print(f"GPU {gpu_id} 不存在或无法访问: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果所有首选GPU都不可用,返回-1表示使用CPU
|
# 如果所有首选GPU都不可用、返回-1表示使用CPU
|
||||||
print("所有指定的GPU都不可用,将使用CPU进行计算")
|
print("所有指定的GPU都不可用、将使用CPU进行计算")
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
|
|
||||||
@ -122,12 +122,12 @@ def _release_engine():
|
|||||||
|
|
||||||
|
|
||||||
def _monitor_thread():
|
def _monitor_thread():
|
||||||
"""监控线程,检查并释放超时未使用的资源"""
|
"""监控线程、检查并释放超时未使用的资源"""
|
||||||
global _ref_count, _last_used_time, _face_app
|
global _ref_count, _last_used_time, _face_app
|
||||||
while True:
|
while True:
|
||||||
time.sleep(5) # 每5秒检查一次
|
time.sleep(5) # 每5秒检查一次
|
||||||
with _lock:
|
with _lock:
|
||||||
# 只有当引擎存在、没有引用且超时,才释放
|
# 只有当引擎存在、没有引用且超时、才释放
|
||||||
if _face_app and _ref_count == 0 and not _is_releasing:
|
if _face_app and _ref_count == 0 and not _is_releasing:
|
||||||
elapsed = time.time() - _last_used_time
|
elapsed = time.time() - _last_used_time
|
||||||
if elapsed > _release_timeout:
|
if elapsed > _release_timeout:
|
||||||
@ -136,7 +136,7 @@ def _monitor_thread():
|
|||||||
|
|
||||||
|
|
||||||
def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
||||||
"""加载人脸识别模型及已知人脸特征库,默认优先使用0号GPU"""
|
"""加载人脸识别模型及已知人脸特征库、默认优先使用0号GPU"""
|
||||||
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
|
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
|
||||||
|
|
||||||
# 确保监控线程只启动一次
|
# 确保监控线程只启动一次
|
||||||
@ -144,11 +144,11 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
|||||||
threading.Thread(target=_monitor_thread, daemon=True, name="FaceMonitor").start()
|
threading.Thread(target=_monitor_thread, daemon=True, name="FaceMonitor").start()
|
||||||
print("Face monitor thread started")
|
print("Face monitor thread started")
|
||||||
|
|
||||||
# 如果正在释放中,等待释放完成
|
# 如果正在释放中、等待释放完成
|
||||||
while _is_releasing:
|
while _is_releasing:
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
# 如果已经初始化,直接返回
|
# 如果已经初始化、直接返回
|
||||||
if _face_app:
|
if _face_app:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -158,7 +158,7 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
|||||||
print("正在初始化InsightFace人脸识别引擎...")
|
print("正在初始化InsightFace人脸识别引擎...")
|
||||||
_face_app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
_face_app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
||||||
|
|
||||||
# 选择合适的GPU,默认优先使用0号
|
# 选择合适的GPU、默认优先使用0号
|
||||||
ctx_id = 0
|
ctx_id = 0
|
||||||
if prefer_gpu:
|
if prefer_gpu:
|
||||||
ctx_id = select_best_gpu(preferred_gpus)
|
ctx_id = select_best_gpu(preferred_gpus)
|
||||||
@ -166,9 +166,9 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
|||||||
_used_gpu_id = ctx_id if _using_gpu else -1
|
_used_gpu_id = ctx_id if _using_gpu else -1
|
||||||
|
|
||||||
if _using_gpu:
|
if _using_gpu:
|
||||||
print(f"成功初始化,使用GPU {ctx_id} 进行计算")
|
print(f"成功初始化、使用GPU {ctx_id} 进行计算")
|
||||||
else:
|
else:
|
||||||
print("成功初始化,使用CPU进行计算")
|
print("成功初始化、使用CPU进行计算")
|
||||||
|
|
||||||
# 准备模型
|
# 准备模型
|
||||||
_face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
|
_face_app.prepare(ctx_id=ctx_id, det_size=(640, 640))
|
||||||
@ -188,10 +188,10 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
|||||||
for person_name, eigenvalue_data in face_data.items():
|
for person_name, eigenvalue_data in face_data.items():
|
||||||
# 处理特征值数据 - 兼容数组和字符串两种格式
|
# 处理特征值数据 - 兼容数组和字符串两种格式
|
||||||
if isinstance(eigenvalue_data, np.ndarray):
|
if isinstance(eigenvalue_data, np.ndarray):
|
||||||
# 如果已经是numpy数组,直接使用
|
# 如果已经是numpy数组、直接使用
|
||||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||||
elif isinstance(eigenvalue_data, str):
|
elif isinstance(eigenvalue_data, str):
|
||||||
# 清理字符串:移除方括号、换行符和多余空格
|
# 清理字符串: 移除方括号、换行符和多余空格
|
||||||
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
|
cleaned = eigenvalue_data.replace('[', '').replace(']', '').replace('\n', '').strip()
|
||||||
# 按空格或逗号分割(处理可能的不同分隔符)
|
# 按空格或逗号分割(处理可能的不同分隔符)
|
||||||
values = [v for v in cleaned.split() if v]
|
values = [v for v in cleaned.split() if v]
|
||||||
@ -217,7 +217,7 @@ def load_model(prefer_gpu=True, preferred_gpus=[0, 1]):
|
|||||||
|
|
||||||
|
|
||||||
def detect(frame, threshold=0.4):
|
def detect(frame, threshold=0.4):
|
||||||
"""检测并识别人脸,返回结果元组(是否匹配到已知人脸, 结果字符串)"""
|
"""检测并识别人脸、返回结果元组(是否匹配到已知人脸, 结果字符串)"""
|
||||||
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
|
global _face_app, _known_faces_embeddings, _known_faces_names, _using_gpu, _used_gpu_id
|
||||||
global _ref_count, _last_used_time
|
global _ref_count, _last_used_time
|
||||||
|
|
||||||
@ -248,7 +248,7 @@ def detect(frame, threshold=0.4):
|
|||||||
return (False, "人脸识别引擎不可用或未初始化")
|
return (False, "人脸识别引擎不可用或未初始化")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 如果使用GPU,确保输入帧在处理前是连续的数组
|
# 如果使用GPU、确保输入帧在处理前是连续的数组
|
||||||
if _using_gpu and not frame.flags.contiguous:
|
if _using_gpu and not frame.flags.contiguous:
|
||||||
frame = np.ascontiguousarray(frame)
|
frame = np.ascontiguousarray(frame)
|
||||||
|
|
||||||
@ -285,7 +285,7 @@ def detect(frame, threshold=0.4):
|
|||||||
# 判断匹配结果
|
# 判断匹配结果
|
||||||
is_match = max_sim >= threshold
|
is_match = max_sim >= threshold
|
||||||
if is_match:
|
if is_match:
|
||||||
has_matched = True # 只要有一个匹配成功,就标记为True
|
has_matched = True # 只要有一个匹配成功、就标记为True
|
||||||
|
|
||||||
bbox = face.bbox
|
bbox = face.bbox
|
||||||
result_parts.append(
|
result_parts.append(
|
||||||
@ -298,12 +298,12 @@ def detect(frame, threshold=0.4):
|
|||||||
else:
|
else:
|
||||||
result_str = "; ".join(result_parts)
|
result_str = "; ".join(result_parts)
|
||||||
|
|
||||||
# 减少引用计数,确保线程安全
|
# 减少引用计数、确保线程安全
|
||||||
with _lock:
|
with _lock:
|
||||||
_ref_count = max(0, _ref_count - 1)
|
_ref_count = max(0, _ref_count - 1)
|
||||||
# 持续使用时更新最后使用时间
|
# 持续使用时更新最后使用时间
|
||||||
if _ref_count > 0:
|
if _ref_count > 0:
|
||||||
_last_used_time = time.time()
|
_last_used_time = time.time()
|
||||||
|
|
||||||
# 第一个返回值为:是否匹配到已知人脸
|
# 第一个返回值为: 是否匹配到已知人脸
|
||||||
return (has_matched, result_str)
|
return (has_matched, result_str)
|
@ -61,12 +61,12 @@ def _release_engine():
|
|||||||
|
|
||||||
|
|
||||||
def _monitor_thread():
|
def _monitor_thread():
|
||||||
"""监控线程,优化检查逻辑"""
|
"""监控线程、优化检查逻辑"""
|
||||||
global _ref_count, _last_used_time, _ocr_engine
|
global _ref_count, _last_used_time, _ocr_engine
|
||||||
while True:
|
while True:
|
||||||
time.sleep(5) # 每5秒检查一次
|
time.sleep(5) # 每5秒检查一次
|
||||||
with _lock:
|
with _lock:
|
||||||
# 只有当引擎存在、没有引用且超时,才释放
|
# 只有当引擎存在、没有引用且超时、才释放
|
||||||
if _ocr_engine and _ref_count == 0 and not _is_releasing:
|
if _ocr_engine and _ref_count == 0 and not _is_releasing:
|
||||||
elapsed = time.time() - _last_used_time
|
elapsed = time.time() - _last_used_time
|
||||||
if elapsed > _release_timeout:
|
if elapsed > _release_timeout:
|
||||||
@ -100,7 +100,7 @@ def load_model():
|
|||||||
|
|
||||||
|
|
||||||
def detect(frame):
|
def detect(frame):
|
||||||
"""OCR检测,优化引用计数管理"""
|
"""OCR检测、优化引用计数管理"""
|
||||||
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time
|
global _ocr_engine, _forbidden_words, _conf_threshold, _ref_count, _last_used_time
|
||||||
|
|
||||||
# 验证前置条件
|
# 验证前置条件
|
||||||
@ -178,7 +178,7 @@ def detect(frame):
|
|||||||
return (False, f"检测错误: {str(e)}")
|
return (False, f"检测错误: {str(e)}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 减少引用计数,确保线程安全
|
# 减少引用计数、确保线程安全
|
||||||
with _lock:
|
with _lock:
|
||||||
_ref_count = max(0, _ref_count - 1)
|
_ref_count = max(0, _ref_count - 1)
|
||||||
# 持续使用时更新最后使用时间
|
# 持续使用时更新最后使用时间
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import cv2
|
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
# 全局变量
|
# 全局变量
|
||||||
@ -24,7 +23,7 @@ def load_model():
|
|||||||
|
|
||||||
|
|
||||||
def detect(frame, conf_threshold=0.2):
|
def detect(frame, conf_threshold=0.2):
|
||||||
"""YOLO目标检测,返回(是否识别到, 结果字符串)"""
|
"""YOLO目标检测、返回(是否识别到, 结果字符串)"""
|
||||||
global _yolo_model
|
global _yolo_model
|
||||||
|
|
||||||
if not _yolo_model or frame is None:
|
if not _yolo_model or frame is None:
|
||||||
|
@ -14,4 +14,3 @@ config.read(config_path, encoding="utf-8")
|
|||||||
SERVER_CONFIG = config["server"]
|
SERVER_CONFIG = config["server"]
|
||||||
MYSQL_CONFIG = config["mysql"]
|
MYSQL_CONFIG = config["mysql"]
|
||||||
JWT_CONFIG = config["jwt"]
|
JWT_CONFIG = config["jwt"]
|
||||||
LIVE_CONFIG = config["live"]
|
|
||||||
|
2
main.py
2
main.py
@ -47,7 +47,7 @@ if __name__ == "__main__":
|
|||||||
YOLO_MODEL_PATH = r"/core/models\best.pt"
|
YOLO_MODEL_PATH = r"/core/models\best.pt"
|
||||||
OCR_CONFIG_PATH = r"/core/config\config.yaml"
|
OCR_CONFIG_PATH = r"/core/config\config.yaml"
|
||||||
|
|
||||||
# 初始化项目(默认端口设为8000,避免初始化失败时port未定义)
|
# 初始化项目(默认端口设为8000、避免初始化失败时port未定义)
|
||||||
port = int(SERVER_CONFIG.get("port", 8000))
|
port = int(SERVER_CONFIG.get("port", 8000))
|
||||||
|
|
||||||
# 启动 UVicorn 服务
|
# 启动 UVicorn 服务
|
||||||
|
@ -9,8 +9,6 @@ from passlib.context import CryptContext
|
|||||||
from ds.config import JWT_CONFIG
|
from ds.config import JWT_CONFIG
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
|
|
||||||
# 移除这里的 from service.user_service import UserResponse 导入
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 密码加密配置
|
# 密码加密配置
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -23,7 +21,7 @@ SECRET_KEY = JWT_CONFIG["secret_key"]
|
|||||||
ALGORITHM = JWT_CONFIG["algorithm"]
|
ALGORITHM = JWT_CONFIG["algorithm"]
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
|
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
|
||||||
|
|
||||||
# OAuth2 依赖(从请求头获取 Token、格式:Bearer <token>)
|
# OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>)
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||||
|
|
||||||
|
|
||||||
@ -63,7 +61,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
|||||||
# ------------------------------
|
# ------------------------------
|
||||||
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
|
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
|
||||||
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
||||||
# 延迟导入,打破循环依赖
|
# 延迟导入、打破循环依赖
|
||||||
from schema.user_schema import UserResponse # 在这里导入
|
from schema.user_schema import UserResponse # 在这里导入
|
||||||
|
|
||||||
# 认证失败异常
|
# 认证失败异常
|
||||||
@ -101,4 +99,4 @@ def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise credentials_exception from e
|
raise credentials_exception from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
@ -8,7 +8,7 @@ from schema.response_schema import APIResponse
|
|||||||
|
|
||||||
|
|
||||||
async def global_exception_handler(request: Request, exc: Exception):
|
async def global_exception_handler(request: Request, exc: Exception):
|
||||||
"""全局异常处理器:所有未捕获的异常都会在这里统一处理"""
|
"""全局异常处理器: 所有未捕获的异常都会在这里统一处理"""
|
||||||
# 1. 请求参数验证错误(Pydantic 校验失败)
|
# 1. 请求参数验证错误(Pydantic 校验失败)
|
||||||
if isinstance(exc, RequestValidationError):
|
if isinstance(exc, RequestValidationError):
|
||||||
error_details = []
|
error_details = []
|
||||||
@ -18,7 +18,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
content=APIResponse(
|
content=APIResponse(
|
||||||
code=400,
|
code=400,
|
||||||
message=f"请求参数错误:{'; '.join(error_details)}",
|
message=f"请求参数错误: {'; '.join(error_details)}",
|
||||||
data=None
|
data=None
|
||||||
).model_dump()
|
).model_dump()
|
||||||
)
|
)
|
||||||
@ -52,7 +52,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content=APIResponse(
|
content=APIResponse(
|
||||||
code=500,
|
code=500,
|
||||||
message=f"数据库错误:{str(exc)}",
|
message=f"数据库错误: {str(exc)}",
|
||||||
data=None
|
data=None
|
||||||
).model_dump()
|
).model_dump()
|
||||||
)
|
)
|
||||||
@ -62,7 +62,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content=APIResponse(
|
content=APIResponse(
|
||||||
code=500,
|
code=500,
|
||||||
message=f"服务器内部错误:{str(exc)}",
|
message=f"服务器内部错误: {str(exc)}",
|
||||||
data=None
|
data=None
|
||||||
).model_dump()
|
).model_dump()
|
||||||
)
|
)
|
||||||
|
@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 请求模型(新增记录用,极简)
|
# 请求模型
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
class DeviceActionCreate(BaseModel):
|
class DeviceActionCreate(BaseModel):
|
||||||
"""设备操作记录创建模型(0=离线,1=上线)"""
|
"""设备操作记录创建模型(0=离线、1=上线)"""
|
||||||
client_ip: str = Field(..., description="客户端IP")
|
client_ip: str = Field(..., description="客户端IP")
|
||||||
action: int = Field(..., ge=0, le=1, description="操作状态(0=离线,1=上线)")
|
action: int = Field(..., ge=0, le=1, description="操作状态(0=离线、1=上线)")
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -19,7 +19,7 @@ class DeviceActionResponse(BaseModel):
|
|||||||
"""设备操作记录响应模型(与自增表对齐)"""
|
"""设备操作记录响应模型(与自增表对齐)"""
|
||||||
id: int = Field(..., description="自增主键ID")
|
id: int = Field(..., description="自增主键ID")
|
||||||
client_ip: Optional[str] = Field(None, description="客户端IP")
|
client_ip: Optional[str] = Field(None, description="客户端IP")
|
||||||
action: Optional[int] = Field(None, description="操作状态(0=离线,1=上线)")
|
action: Optional[int] = Field(None, description="操作状态(0=离线、1=上线)")
|
||||||
created_at: datetime = Field(..., description="记录创建时间")
|
created_at: datetime = Field(..., description="记录创建时间")
|
||||||
updated_at: datetime = Field(..., description="记录更新时间")
|
updated_at: datetime = Field(..., description="记录更新时间")
|
||||||
|
|
||||||
|
@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
|
|||||||
# 请求模型(前端传参校验)
|
# 请求模型(前端传参校验)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
class FaceCreateRequest(BaseModel):
|
class FaceCreateRequest(BaseModel):
|
||||||
"""创建人脸记录请求模型(无需ID,由数据库自增)"""
|
"""创建人脸记录请求模型(无需ID、由数据库自增)"""
|
||||||
name: str = Field(None, max_length=255, description="名称(可选,最长255字符)")
|
name: str = Field(None, max_length=255, description="名称(可选、最长255字符)")
|
||||||
|
|
||||||
|
|
||||||
class FaceUpdateRequest(BaseModel):
|
class FaceUpdateRequest(BaseModel):
|
||||||
@ -20,7 +20,7 @@ class FaceUpdateRequest(BaseModel):
|
|||||||
# 响应模型(后端返回数据)
|
# 响应模型(后端返回数据)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
class FaceResponse(BaseModel):
|
class FaceResponse(BaseModel):
|
||||||
"""人脸记录响应模型(仍包含ID,由数据库生成后返回)"""
|
"""人脸记录响应模型(仍包含ID、由数据库生成后返回)"""
|
||||||
id: int = Field(..., description="主键ID(数据库自增)")
|
id: int = Field(..., description="主键ID(数据库自增)")
|
||||||
name: str = Field(None, description="名称")
|
name: str = Field(None, description="名称")
|
||||||
eigenvalue: str | None = Field(None, description="特征(可为空)")
|
eigenvalue: str | None = Field(None, description="特征(可为空)")
|
||||||
|
@ -5,9 +5,9 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
class APIResponse(BaseModel):
|
class APIResponse(BaseModel):
|
||||||
"""统一 API 响应模型(所有接口必返此格式)"""
|
"""统一 API 响应模型(所有接口必返此格式)"""
|
||||||
code: int = Field(..., description="状态码:200=成功、4xx=客户端错误、5xx=服务端错误")
|
code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
|
||||||
message: str = Field(..., description="响应信息:成功/错误描述")
|
message: str = Field(..., description="响应信息: 成功/错误描述")
|
||||||
data: Optional[Any] = Field(None, description="响应数据:成功时返回、错误时为 None")
|
data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
|
||||||
|
|
||||||
# Pydantic V2 配置(支持从 ORM 对象转换)
|
# Pydantic V2 配置(支持从 ORM 对象转换)
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
|
|||||||
# ------------------------------
|
# ------------------------------
|
||||||
class SensitiveCreateRequest(BaseModel):
|
class SensitiveCreateRequest(BaseModel):
|
||||||
"""创建敏感信息记录请求模型"""
|
"""创建敏感信息记录请求模型"""
|
||||||
# 移除了id字段,由数据库自动生成
|
# 移除了id字段、由数据库自动生成
|
||||||
name: str = Field(None, max_length=255, description="名称")
|
name: str = Field(None, max_length=255, description="名称")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query, Path
|
||||||
from mysql.connector import Error as MySQLError
|
from mysql.connector import Error as MySQLError
|
||||||
|
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
@ -9,7 +9,6 @@ from schema.device_action_schema import (
|
|||||||
)
|
)
|
||||||
from schema.response_schema import APIResponse
|
from schema.response_schema import APIResponse
|
||||||
|
|
||||||
|
|
||||||
# 路由配置
|
# 路由配置
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/device/actions",
|
prefix="/device/actions",
|
||||||
@ -18,11 +17,11 @@ router = APIRouter(
|
|||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 内部方法:新增设备操作记录(适配id自增)
|
# 内部方法: 新增设备操作记录(适配id自增)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
|
def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
|
||||||
"""
|
"""
|
||||||
新增设备操作记录(内部方法,非接口)
|
新增设备操作记录(内部方法、非接口)
|
||||||
:param action_data: 含client_ip和action(0/1)
|
:param action_data: 含client_ip和action(0/1)
|
||||||
:return: 新增的完整记录
|
:return: 新增的完整记录
|
||||||
"""
|
"""
|
||||||
@ -32,7 +31,7 @@ def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 插入SQL(id自增,依赖数据库自动生成)
|
# 插入SQL(id自增、依赖数据库自动生成)
|
||||||
insert_query = """
|
insert_query = """
|
||||||
INSERT INTO device_action
|
INSERT INTO device_action
|
||||||
(client_ip, action, created_at, updated_at)
|
(client_ip, action, created_at, updated_at)
|
||||||
@ -54,20 +53,20 @@ def add_device_action(action_data: DeviceActionCreate) -> DeviceActionResponse:
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"新增记录失败:{str(e)}") from e
|
raise Exception(f"新增记录失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 接口:分页查询操作记录列表(仅返回 total + device_actions)
|
# 接口: 分页查询操作记录列表(仅返回 total + device_actions)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
|
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
|
||||||
async def get_device_action_list(
|
async def get_device_action_list(
|
||||||
page: int = Query(1, ge=1, description="页码,默认1"),
|
page: int = Query(1, ge=1, description="页码、默认1"),
|
||||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100"),
|
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
|
||||||
client_ip: str = Query(None, description="按客户端IP筛选"),
|
client_ip: str = Query(None, description="按客户端IP筛选"),
|
||||||
action: int = Query(None, ge=0, le=1, description="按状态筛选(0=离线,1=上线)")
|
action: int = Query(None, ge=0, le=1, description="按状态筛选(0=离线、1=上线)")
|
||||||
):
|
):
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
@ -75,7 +74,7 @@ async def get_device_action_list(
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 1. 构建筛选条件(参数化查询,避免注入)
|
# 1. 构建筛选条件(参数化查询、避免注入)
|
||||||
where_clause = []
|
where_clause = []
|
||||||
params = []
|
params = []
|
||||||
if client_ip:
|
if client_ip:
|
||||||
@ -92,13 +91,13 @@ async def get_device_action_list(
|
|||||||
cursor.execute(count_sql, params)
|
cursor.execute(count_sql, params)
|
||||||
total = cursor.fetchone()["total"]
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
# 3. 分页查询记录(按创建时间倒序,确保最新记录在前)
|
# 3. 分页查询记录(按创建时间倒序、确保最新记录在前)
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
list_sql = "SELECT * FROM device_action"
|
list_sql = "SELECT * FROM device_action"
|
||||||
if where_clause:
|
if where_clause:
|
||||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||||
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||||
params.extend([page_size, offset]) # 追加分页参数(page/page_size仅用于查询,不返回)
|
params.extend([page_size, offset]) # 追加分页参数(page/page_size仅用于查询、不返回)
|
||||||
|
|
||||||
cursor.execute(list_sql, params)
|
cursor.execute(list_sql, params)
|
||||||
action_list = cursor.fetchall()
|
action_list = cursor.fetchall()
|
||||||
@ -114,6 +113,46 @@ async def get_device_action_list(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"查询记录失败:{str(e)}") from e
|
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
|
||||||
|
async def get_device_actions_by_ip(
|
||||||
|
client_ip: str = Path(..., description="客户端IP地址")
|
||||||
|
):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 1. 查询总记录数
|
||||||
|
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
|
||||||
|
cursor.execute(count_sql, (client_ip,))
|
||||||
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
|
# 2. 查询该IP的所有记录(按创建时间倒序)
|
||||||
|
list_sql = """
|
||||||
|
SELECT * FROM device_action
|
||||||
|
WHERE client_ip = %s
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
"""
|
||||||
|
cursor.execute(list_sql, (client_ip,))
|
||||||
|
action_list = cursor.fetchall()
|
||||||
|
|
||||||
|
# 3. 返回结果
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="查询成功",
|
||||||
|
data=DeviceActionListResponse(
|
||||||
|
total=total,
|
||||||
|
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
@ -37,7 +37,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
|||||||
if not cursor.fetchone():
|
if not cursor.fetchone():
|
||||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||||
|
|
||||||
# 报警次数加1,并更新时间戳
|
# 报警次数加1、并更新时间戳
|
||||||
update_query = """
|
update_query = """
|
||||||
UPDATE devices
|
UPDATE devices
|
||||||
SET alarm_count = alarm_count + 1,
|
SET alarm_count = alarm_count + 1,
|
||||||
@ -51,7 +51,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"更新报警次数失败:{str(e)}") from e
|
raise Exception(f"更新报警次数失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -99,7 +99,7 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"更新设备在线状态失败:{str(e)}") from e
|
raise Exception(f"更新设备在线状态失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
|
|||||||
# 返回信息
|
# 返回信息
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message=f"设备IP {device_data.ip} 已存在,返回已有设备信息",
|
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
|
||||||
data=DeviceResponse(** existing_device)
|
data=DeviceResponse(** existing_device)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -173,9 +173,9 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"创建设备失败:{str(e)}") from e
|
raise Exception(f"创建设备失败: {str(e)}") from e
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise Exception(f"设备详细信息JSON序列化失败:{str(e)}") from e
|
raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
@ -185,8 +185,8 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
|
|||||||
|
|
||||||
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
||||||
async def get_device_list(
|
async def get_device_list(
|
||||||
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||||
device_type: str = Query(None, description="按设备类型筛选"),
|
device_type: str = Query(None, description="按设备类型筛选"),
|
||||||
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
|
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
|
||||||
):
|
):
|
||||||
@ -233,6 +233,6 @@ async def get_device_list(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"获取设备列表失败:{str(e)}") from e
|
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
@ -18,27 +18,27 @@ router = APIRouter(
|
|||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 1. 创建人脸记录(核心修正:ID 数据库自增,前端无需传)
|
# 1. 创建人脸记录(核心修正: ID 数据库自增、前端无需传)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件,ID自增)")
|
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件、ID自增)")
|
||||||
async def create_face(
|
async def create_face(
|
||||||
# 前端仅需传:name(可选,Form格式)、file(必传,文件)
|
# 前端仅需传: name(可选、Form格式)、file(必传、文件)
|
||||||
name: str = Form(None, max_length=255, description="名称(可选)"),
|
name: str = Form(None, max_length=255, description="名称(可选)"),
|
||||||
file: UploadFile = File(..., description="人脸文件(必传,暂不处理内容)")
|
file: UploadFile = File(..., description="人脸文件(必传、暂不处理内容)")
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
创建人脸记录:
|
创建人脸记录:
|
||||||
- 需登录认证
|
- 需登录认证
|
||||||
- 前端传参:multipart/form-data 表单(name 可选,file 必传)
|
- 前端传参: multipart/form-data 表单(name 可选、file 必传)
|
||||||
- ID 由数据库自动生成,无需前端传入
|
- ID 由数据库自动生成、无需前端传入
|
||||||
- 暂不处理文件内容,eigenvalue 设为 None
|
- 暂不处理文件内容、eigenvalue 设为 None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 调用你的方法
|
# 调用你的方法
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
# 1. 用模型校验 name(仅校验长度,无需ID)
|
# 1. 用模型校验 name(仅校验长度、无需ID)
|
||||||
face_create = FaceCreateRequest(name=name)
|
face_create = FaceCreateRequest(name=name)
|
||||||
|
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
@ -57,9 +57,9 @@ async def create_face(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 打印数组长度
|
# 打印数组长度
|
||||||
print(f"文件大小:{len(file_content)} 字节")
|
print(f"文件大小: {len(file_content)} 字节")
|
||||||
|
|
||||||
# 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None)
|
# 2. 插入数据库: 无需传 ID(自增)、只传 name 和 eigenvalue(None)
|
||||||
insert_query = """
|
insert_query = """
|
||||||
INSERT INTO face (name, eigenvalue)
|
INSERT INTO face (name, eigenvalue)
|
||||||
VALUES (%s, %s)
|
VALUES (%s, %s)
|
||||||
@ -67,7 +67,7 @@ async def create_face(
|
|||||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
|
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
# 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录)
|
# 3. 获取数据库自动生成的 ID(关键: 用 LAST_INSERT_ID() 查刚插入的记录)
|
||||||
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
|
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
|
||||||
cursor.execute(select_new_query)
|
cursor.execute(select_new_query)
|
||||||
created_face = cursor.fetchone()
|
created_face = cursor.fetchone()
|
||||||
@ -75,12 +75,12 @@ async def create_face(
|
|||||||
if not created_face:
|
if not created_face:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail="创建人脸记录成功,但无法获取新创建的记录"
|
detail="创建人脸记录成功、但无法获取新创建的记录"
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=201,
|
code=201,
|
||||||
message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})",
|
message=f"人脸记录创建成功(ID: {created_face['id']}、文件名: {file.filename})",
|
||||||
data=FaceResponse(** created_face)
|
data=FaceResponse(** created_face)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
@ -89,13 +89,13 @@ async def create_face(
|
|||||||
# 改为使用HTTPException
|
# 改为使用HTTPException
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"创建人脸记录失败:{str(e)}"
|
detail=f"创建人脸记录失败: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 捕获其他可能的异常
|
# 捕获其他可能的异常
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"服务器错误:{str(e)}"
|
detail=f"服务器错误: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
finally:
|
finally:
|
||||||
await file.close() # 关闭文件流
|
await file.close() # 关闭文件流
|
||||||
@ -113,11 +113,11 @@ async def create_face(
|
|||||||
eigenvalue = str(eigenvalue)
|
eigenvalue = str(eigenvalue)
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 2. 获取单个人脸记录(不变,用自增ID查询)
|
# 2. 获取单个人脸记录(不变、用自增ID查询)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
||||||
async def get_face(
|
async def get_face(
|
||||||
face_id: int, # 这里的 ID 是数据库自增的,前端从创建响应中获取
|
face_id: int, # 这里的 ID 是数据库自增的、前端从创建响应中获取
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
current_user: UserResponse = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
conn = None
|
conn = None
|
||||||
@ -145,7 +145,7 @@ async def get_face(
|
|||||||
# 改为使用HTTPException
|
# 改为使用HTTPException
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"查询人脸记录失败:{str(e)}"
|
detail=f"查询人脸记录失败: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
@ -176,14 +176,14 @@ async def get_all_faces(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"查询所有人脸记录失败:{str(e)}"
|
detail=f"查询所有人脸记录失败: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 4. 更新人脸记录(不变,用自增ID更新)
|
# 4. 更新人脸记录(不变、用自增ID更新)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
|
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
|
||||||
async def update_face(
|
async def update_face(
|
||||||
@ -240,14 +240,14 @@ async def update_face(
|
|||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"更新人脸记录失败:{str(e)}"
|
detail=f"更新人脸记录失败: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 5. 删除人脸记录(不变,用自增ID删除)
|
# 5. 删除人脸记录(不变、用自增ID删除)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
||||||
async def delete_face(
|
async def delete_face(
|
||||||
@ -283,7 +283,7 @@ async def delete_face(
|
|||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"删除人脸记录失败:{str(e)}"
|
detail=f"删除人脸记录失败: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
@ -291,10 +291,10 @@ async def delete_face(
|
|||||||
|
|
||||||
def get_all_face_name_with_eigenvalue() -> dict:
|
def get_all_face_name_with_eigenvalue() -> dict:
|
||||||
"""
|
"""
|
||||||
获取所有人脸的名称及其对应的特征值,组成字典返回
|
获取所有人脸的名称及其对应的特征值、组成字典返回
|
||||||
key: 人脸名称(name)
|
key: 人脸名称(name)
|
||||||
value: 人脸特征值(eigenvalue),若名称重复则返回平均特征值
|
value: 人脸特征值(eigenvalue)、若名称重复则返回平均特征值
|
||||||
注:过滤掉name为None的记录,避免字典key为None的情况
|
注: 过滤掉name为None的记录、避免字典key为None的情况
|
||||||
"""
|
"""
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
@ -303,27 +303,27 @@ def get_all_face_name_with_eigenvalue() -> dict:
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 2. 执行SQL查询:只获取name非空的记录,减少数据传输
|
# 2. 执行SQL查询: 只获取name非空的记录、减少数据传输
|
||||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
faces = cursor.fetchall() # 返回结果:列表套字典,如 [{"name":"张三","eigenvalue":...}, ...]
|
faces = cursor.fetchall() # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...]
|
||||||
|
|
||||||
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
|
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
|
||||||
name_to_eigenvalues = {}
|
name_to_eigenvalues = {}
|
||||||
for face in faces:
|
for face in faces:
|
||||||
name = face["name"]
|
name = face["name"]
|
||||||
eigenvalue = face["eigenvalue"]
|
eigenvalue = face["eigenvalue"]
|
||||||
# 若名称已存在,追加特征值;否则新建列表存储
|
# 若名称已存在、追加特征值;否则新建列表存储
|
||||||
if name in name_to_eigenvalues:
|
if name in name_to_eigenvalues:
|
||||||
name_to_eigenvalues[name].append(eigenvalue)
|
name_to_eigenvalues[name].append(eigenvalue)
|
||||||
else:
|
else:
|
||||||
name_to_eigenvalues[name] = [eigenvalue]
|
name_to_eigenvalues[name] = [eigenvalue]
|
||||||
|
|
||||||
# 4. 构建最终字典:重复名称取平均,唯一名称直接取特征值
|
# 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值
|
||||||
face_dict = {}
|
face_dict = {}
|
||||||
for name, eigenvalues in name_to_eigenvalues.items():
|
for name, eigenvalues in name_to_eigenvalues.items():
|
||||||
|
|
||||||
# 处理特征值:多个则求平均,单个则直接使用
|
# 处理特征值: 多个则求平均、单个则直接使用
|
||||||
if len(eigenvalues) > 1:
|
if len(eigenvalues) > 1:
|
||||||
# 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入)
|
# 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入)
|
||||||
face_dict[name] = get_average_feature(eigenvalues)
|
face_dict[name] = get_average_feature(eigenvalues)
|
||||||
@ -334,8 +334,8 @@ def get_all_face_name_with_eigenvalue() -> dict:
|
|||||||
return face_dict
|
return face_dict
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
# 捕获数据库异常,添加上下文信息后重新抛出(便于定位问题)
|
# 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题)
|
||||||
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
|
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
# 5. 无论是否异常,均释放数据库连接和游标(避免资源泄漏)
|
# 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏)
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
@ -21,7 +21,7 @@ router = APIRouter(
|
|||||||
async def create_sensitive(
|
async def create_sensitive(
|
||||||
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
|
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
|
||||||
"""
|
"""
|
||||||
创建敏感信息记录:
|
创建敏感信息记录:
|
||||||
- 需登录认证
|
- 需登录认证
|
||||||
- 插入新的敏感信息记录到数据库(ID由数据库自动生成)
|
- 插入新的敏感信息记录到数据库(ID由数据库自动生成)
|
||||||
- 返回创建成功信息
|
- 返回创建成功信息
|
||||||
@ -32,7 +32,7 @@ async def create_sensitive(
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 插入新敏感信息记录到数据库(不包含ID,由数据库自动生成)
|
# 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成)
|
||||||
insert_query = """
|
insert_query = """
|
||||||
INSERT INTO sensitives (name)
|
INSERT INTO sensitives (name)
|
||||||
VALUES (%s)
|
VALUES (%s)
|
||||||
@ -56,7 +56,7 @@ async def create_sensitive(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"创建敏感信息记录失败:{str(e)}") from e
|
raise Exception(f"创建敏感信息记录失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ async def get_sensitive(
|
|||||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取单个敏感信息记录:
|
获取单个敏感信息记录:
|
||||||
- 需登录认证
|
- 需登录认证
|
||||||
- 根据ID查询敏感信息记录
|
- 根据ID查询敏感信息记录
|
||||||
- 返回查询到的敏感信息
|
- 返回查询到的敏感信息
|
||||||
@ -98,7 +98,7 @@ async def get_sensitive(
|
|||||||
data=SensitiveResponse(**sensitive)
|
data=SensitiveResponse(**sensitive)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"查询敏感信息记录失败:{str(e)}") from e
|
raise Exception(f"查询敏感信息记录失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ async def get_sensitive(
|
|||||||
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
|
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
|
||||||
async def get_all_sensitives():
|
async def get_all_sensitives():
|
||||||
"""
|
"""
|
||||||
获取所有敏感信息记录:
|
获取所有敏感信息记录:
|
||||||
- 需登录认证
|
- 需登录认证
|
||||||
- 查询所有敏感信息记录(不需要分页)
|
- 查询所有敏感信息记录(不需要分页)
|
||||||
- 返回所有敏感信息列表
|
- 返回所有敏感信息列表
|
||||||
@ -130,7 +130,7 @@ async def get_all_sensitives():
|
|||||||
data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
|
data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"查询所有敏感信息记录失败:{str(e)}") from e
|
raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ async def update_sensitive(
|
|||||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
更新敏感信息记录:
|
更新敏感信息记录:
|
||||||
- 需登录认证
|
- 需登录认证
|
||||||
- 根据ID更新敏感信息记录
|
- 根据ID更新敏感信息记录
|
||||||
- 返回更新后的敏感信息
|
- 返回更新后的敏感信息
|
||||||
@ -203,7 +203,7 @@ async def update_sensitive(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"更新敏感信息记录失败:{str(e)}") from e
|
raise Exception(f"更新敏感信息记录失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -217,7 +217,7 @@ async def delete_sensitive(
|
|||||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
删除敏感信息记录:
|
删除敏感信息记录:
|
||||||
- 需登录认证
|
- 需登录认证
|
||||||
- 根据ID删除敏感信息记录
|
- 根据ID删除敏感信息记录
|
||||||
- 返回删除成功信息
|
- 返回删除成功信息
|
||||||
@ -251,14 +251,14 @@ async def delete_sensitive(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"删除敏感信息记录失败:{str(e)}") from e
|
raise Exception(f"删除敏感信息记录失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
def get_all_sensitive_words() -> list[str]:
|
def get_all_sensitive_words() -> list[str]:
|
||||||
"""
|
"""
|
||||||
获取所有敏感词,返回字符串数组
|
获取所有敏感词、返回字符串数组
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
list[str]: 包含所有敏感词的数组
|
list[str]: 包含所有敏感词的数组
|
||||||
@ -273,7 +273,7 @@ def get_all_sensitive_words() -> list[str]:
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 执行查询,只获取敏感词字段
|
# 执行查询、只获取敏感词字段
|
||||||
query = "SELECT name FROM sensitives ORDER BY id"
|
query = "SELECT name FROM sensitives ORDER BY id"
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
sensitive_records = cursor.fetchall()
|
sensitive_records = cursor.fetchall()
|
||||||
@ -283,7 +283,7 @@ def get_all_sensitive_words() -> list[str]:
|
|||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
# 数据库错误处理
|
# 数据库错误处理
|
||||||
raise MySQLError(f"查询敏感词失败:{str(e)}") from e
|
raise MySQLError(f"查询敏感词失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
# 确保资源正确释放
|
# 确保资源正确释放
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
@ -27,7 +27,7 @@ router = APIRouter(
|
|||||||
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
||||||
async def user_register(request: UserRegisterRequest):
|
async def user_register(request: UserRegisterRequest):
|
||||||
"""
|
"""
|
||||||
用户注册:
|
用户注册:
|
||||||
- 校验用户名是否已存在
|
- 校验用户名是否已存在
|
||||||
- 加密密码后插入数据库
|
- 加密密码后插入数据库
|
||||||
- 返回注册成功信息
|
- 返回注册成功信息
|
||||||
@ -67,7 +67,7 @@ async def user_register(request: UserRegisterRequest):
|
|||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
conn.rollback() # 数据库错误时回滚事务
|
conn.rollback() # 数据库错误时回滚事务
|
||||||
raise Exception(f"注册失败:{str(e)}") from e
|
raise Exception(f"注册失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ async def user_register(request: UserRegisterRequest):
|
|||||||
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
||||||
async def user_login(request: UserLoginRequest):
|
async def user_login(request: UserLoginRequest):
|
||||||
"""
|
"""
|
||||||
用户登录:
|
用户登录:
|
||||||
- 校验用户名是否存在
|
- 校验用户名是否存在
|
||||||
- 校验密码是否正确
|
- 校验密码是否正确
|
||||||
- 生成 JWT Token 并返回
|
- 生成 JWT Token 并返回
|
||||||
@ -89,7 +89,7 @@ async def user_login(request: UserLoginRequest):
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 修复:SQL查询添加 created_at 和 updated_at 字段
|
# 修复: SQL查询添加 created_at 和 updated_at 字段
|
||||||
query = """
|
query = """
|
||||||
SELECT id, username, password, created_at, updated_at
|
SELECT id, username, password, created_at, updated_at
|
||||||
FROM users
|
FROM users
|
||||||
@ -129,7 +129,7 @@ async def user_login(request: UserLoginRequest):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"登录失败:{str(e)}") from e
|
raise Exception(f"登录失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -142,8 +142,8 @@ async def get_current_user_info(
|
|||||||
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取当前登录用户信息:
|
获取当前登录用户信息:
|
||||||
- 需在请求头携带 Token(格式:Bearer <token>)
|
- 需在请求头携带 Token(格式: Bearer <token>)
|
||||||
- 认证通过后返回用户信息
|
- 认证通过后返回用户信息
|
||||||
"""
|
"""
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
|
@ -27,7 +27,7 @@ def init_insightface():
|
|||||||
|
|
||||||
def add_binary_data(binary_data):
|
def add_binary_data(binary_data):
|
||||||
"""
|
"""
|
||||||
接收单张图片的二进制数据,提取特征并保存
|
接收单张图片的二进制数据、提取特征并保存
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
binary_data: 图片的二进制数据(bytes类型)
|
binary_data: 图片的二进制数据(bytes类型)
|
||||||
@ -39,11 +39,11 @@ def add_binary_data(binary_data):
|
|||||||
global _insightface_app, _feature_list
|
global _insightface_app, _feature_list
|
||||||
|
|
||||||
if not _insightface_app:
|
if not _insightface_app:
|
||||||
print("引擎未初始化,无法处理")
|
print("引擎未初始化、无法处理")
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 直接处理二进制数据:转换为图像格式
|
# 直接处理二进制数据: 转换为图像格式
|
||||||
img = Image.open(BytesIO(binary_data))
|
img = Image.open(BytesIO(binary_data))
|
||||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
@ -70,15 +70,15 @@ def get_average_feature(features=None):
|
|||||||
计算多个特征向量的平均值
|
计算多个特征向量的平均值
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
features: 可选,特征值列表。如果未提供,则使用全局存储的_feature_list
|
features: 可选、特征值列表。如果未提供、则使用全局存储的_feature_list
|
||||||
每个元素可以是字符串格式或numpy数组
|
每个元素可以是字符串格式或numpy数组
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
单一平均特征向量的numpy数组,若无可计算数据则返回None
|
单一平均特征向量的numpy数组、若无可计算数据则返回None
|
||||||
"""
|
"""
|
||||||
global _feature_list
|
global _feature_list
|
||||||
|
|
||||||
# 如果未提供features参数,则使用全局特征列表
|
# 如果未提供features参数、则使用全局特征列表
|
||||||
if features is None:
|
if features is None:
|
||||||
features = _feature_list
|
features = _feature_list
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ def get_average_feature(features=None):
|
|||||||
processed_features.append(embedding_np)
|
processed_features.append(embedding_np)
|
||||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||||
else:
|
else:
|
||||||
print(f"跳过第 {i + 1} 个特征值,不是一维数组")
|
print(f"跳过第 {i + 1} 个特征值、不是一维数组")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
||||||
@ -118,12 +118,12 @@ def get_average_feature(features=None):
|
|||||||
# 检查所有特征向量维度是否相同
|
# 检查所有特征向量维度是否相同
|
||||||
dims = {feat.shape[0] for feat in processed_features}
|
dims = {feat.shape[0] for feat in processed_features}
|
||||||
if len(dims) > 1:
|
if len(dims) > 1:
|
||||||
print(f"特征值维度不一致,无法计算平均值。检测到的维度: {dims}")
|
print(f"特征值维度不一致、无法计算平均值。检测到的维度: {dims}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 计算平均值
|
# 计算平均值
|
||||||
avg_feature = np.mean(processed_features, axis=0)
|
avg_feature = np.mean(processed_features, axis=0)
|
||||||
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量,维度: {avg_feature.shape[0]}")
|
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量、维度: {avg_feature.shape[0]}")
|
||||||
|
|
||||||
return avg_feature
|
return avg_feature
|
||||||
|
|
||||||
|
80
ws/ws.py
80
ws/ws.py
@ -21,7 +21,7 @@ WS_ENDPOINT = "/ws" # WebSocket端点路径
|
|||||||
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
|
FRAME_QUEUE_SIZE = 1 # 帧队列大小限制
|
||||||
|
|
||||||
|
|
||||||
# 工具函数:获取格式化时间字符串(统一时间戳格式)
|
# 工具函数: 获取格式化时间字符串(统一时间戳格式)
|
||||||
def get_current_time_str() -> str:
|
def get_current_time_str() -> str:
|
||||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
@ -65,32 +65,32 @@ class ClientConnection:
|
|||||||
"client_ip": self.client_ip
|
"client_ip": self.client_ip
|
||||||
}
|
}
|
||||||
await self.websocket.send_json(frame_permit_msg)
|
await self.websocket.send_json(frame_permit_msg)
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:已发送帧发送许可信号(取帧后立即通知)")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 已发送帧发送许可信号(取帧后立即通知)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧许可信号发送失败 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧许可信号发送失败 - {str(e)}")
|
||||||
|
|
||||||
async def consume_frames(self) -> None:
|
async def consume_frames(self) -> None:
|
||||||
"""消费队列中的帧并处理(核心调整:取帧后立即发许可,再处理帧)"""
|
"""消费队列中的帧并处理(核心调整: 取帧后立即发许可、再处理帧)"""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# 1. 从队列取出帧(阻塞直到有帧可用)
|
# 1. 从队列取出帧(阻塞直到有帧可用)
|
||||||
frame_data = await self.frame_queue.get()
|
frame_data = await self.frame_queue.get()
|
||||||
|
|
||||||
# -------------------------- 核心修改:取出帧后立即发送下一帧许可 --------------------------
|
# -------------------------- 核心修改: 取出帧后立即发送下一帧许可 --------------------------
|
||||||
await self.send_frame_permit() # 取帧即通知客户端发下一帧,无需等处理完成
|
await self.send_frame_permit() # 取帧即通知客户端发下一帧、无需等处理完成
|
||||||
# -----------------------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------------------
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 2. 处理取出的帧(即使处理慢,客户端也已收到许可,可提前准备下一帧)
|
# 2. 处理取出的帧(即使处理慢、客户端也已收到许可、可提前准备下一帧)
|
||||||
await self.process_frame(frame_data)
|
await self.process_frame(frame_data)
|
||||||
finally:
|
finally:
|
||||||
# 3. 标记帧任务完成(无论处理成功/失败,都需清理队列)
|
# 3. 标记帧任务完成(无论处理成功/失败、都需清理队列)
|
||||||
self.frame_queue.task_done()
|
self.frame_queue.task_done()
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费任务已取消")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费任务已取消")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:帧消费逻辑错误 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 帧消费逻辑错误 - {str(e)}")
|
||||||
|
|
||||||
async def process_frame(self, frame_data: bytes) -> None:
|
async def process_frame(self, frame_data: bytes) -> None:
|
||||||
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)"""
|
"""处理单帧图像数据(检测违规后发送危险通知 + 调用违规次数加一方法)"""
|
||||||
@ -98,31 +98,31 @@ class ClientConnection:
|
|||||||
nparr = np.frombuffer(frame_data, np.uint8)
|
nparr = np.frombuffer(frame_data, np.uint8)
|
||||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||||
if img is None:
|
if img is None:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:无法解析图像数据")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 无法解析图像数据")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 确保图像保存目录存在
|
# 确保图像保存目录存在
|
||||||
os.makedirs('images', exist_ok=True)
|
os.makedirs('images', exist_ok=True)
|
||||||
|
|
||||||
# 保存图像(按IP+时间戳命名,避免冲突)
|
# 保存图像(按IP+时间戳命名、避免冲突)
|
||||||
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
|
filename = f"images/{self.client_ip.replace('.', '_')}_{get_current_time_file_str()}.jpg"
|
||||||
try:
|
try:
|
||||||
cv2.imwrite(filename, img)
|
cv2.imwrite(filename, img)
|
||||||
print(f"[{get_current_time_str()}] 图像已保存至:{filename}")
|
print(f"[{get_current_time_str()}] 图像已保存至: {filename}")
|
||||||
has_violation, data, type = detect(img)
|
has_violation, data, type = detect(img)
|
||||||
print(has_violation)
|
print(has_violation)
|
||||||
print(type)
|
print(type)
|
||||||
print(data)
|
print(data)
|
||||||
if has_violation:
|
if has_violation:
|
||||||
print(
|
print(
|
||||||
f"[{get_current_time_str()}] 客户端{self.client_ip}:检测到违规 - 类型: {type}, 详情: {data}")
|
f"[{get_current_time_str()}] 客户端{self.client_ip}: 检测到违规 - 类型: {type}, 详情: {data}")
|
||||||
|
|
||||||
# 调用违规次数加一方法
|
# 调用违规次数加一方法
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
|
await asyncio.to_thread(increment_alarm_count_by_ip, self.client_ip)
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数已+1")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数已+1")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:违规次数更新失败 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 违规次数更新失败 - {str(e)}")
|
||||||
|
|
||||||
# 发送「危险通知」
|
# 发送「危险通知」
|
||||||
danger_msg = {
|
danger_msg = {
|
||||||
@ -132,9 +132,9 @@ class ClientConnection:
|
|||||||
}
|
}
|
||||||
await self.websocket.send_json(danger_msg)
|
await self.websocket.send_json(danger_msg)
|
||||||
else:
|
else:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:未检测到违规")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 未检测到违规")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{self.client_ip}:图像处理错误 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{self.client_ip}: 图像处理错误 - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
# 全局状态管理
|
# 全局状态管理
|
||||||
@ -149,7 +149,7 @@ async def heartbeat_checker():
|
|||||||
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
|
timeout_ips = [ip for ip, conn in connected_clients.items() if not conn.is_alive()]
|
||||||
|
|
||||||
if timeout_ips:
|
if timeout_ips:
|
||||||
print(f"[{current_time}] 心跳检查:{len(timeout_ips)}个客户端超时(IP:{timeout_ips})")
|
print(f"[{current_time}] 心跳检查: {len(timeout_ips)}个客户端超时(IP: {timeout_ips})")
|
||||||
for ip in timeout_ips:
|
for ip in timeout_ips:
|
||||||
try:
|
try:
|
||||||
conn = connected_clients[ip]
|
conn = connected_clients[ip]
|
||||||
@ -162,13 +162,13 @@ async def heartbeat_checker():
|
|||||||
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
|
await asyncio.to_thread(update_online_status_by_ip, ip, 0)
|
||||||
action_data = DeviceActionCreate(client_ip=ip, action=0)
|
action_data = DeviceActionCreate(client_ip=ip, action=0)
|
||||||
await asyncio.to_thread(add_device_action, action_data)
|
await asyncio.to_thread(add_device_action, action_data)
|
||||||
print(f"[{current_time}] 客户端{ip}:已标记为离线并记录操作")
|
print(f"[{current_time}] 客户端{ip}: 已标记为离线并记录操作")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{current_time}] 客户端{ip}:离线状态更新失败 - {str(e)}")
|
print(f"[{current_time}] 客户端{ip}: 离线状态更新失败 - {str(e)}")
|
||||||
finally:
|
finally:
|
||||||
connected_clients.pop(ip, None)
|
connected_clients.pop(ip, None)
|
||||||
else:
|
else:
|
||||||
print(f"[{current_time}] 心跳检查:{len(connected_clients)}个客户端在线")
|
print(f"[{current_time}] 心跳检查: {len(connected_clients)}个客户端在线")
|
||||||
|
|
||||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ async def heartbeat_checker():
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global heartbeat_task
|
global heartbeat_task
|
||||||
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
heartbeat_task = asyncio.create_task(heartbeat_checker())
|
||||||
print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID:{id(heartbeat_task)})")
|
print(f"[{get_current_time_str()}] 全局心跳检查任务启动(任务ID: {id(heartbeat_task)})")
|
||||||
yield
|
yield
|
||||||
if heartbeat_task and not heartbeat_task.done():
|
if heartbeat_task and not heartbeat_task.done():
|
||||||
heartbeat_task.cancel()
|
heartbeat_task.cancel()
|
||||||
@ -198,11 +198,11 @@ async def send_heartbeat_ack(conn: ClientConnection):
|
|||||||
"client_ip": conn.client_ip
|
"client_ip": conn.client_ip
|
||||||
}
|
}
|
||||||
await conn.websocket.send_json(heartbeat_ack_msg)
|
await conn.websocket.send_json(heartbeat_ack_msg)
|
||||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:已发送心跳确认")
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 已发送心跳确认")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
connected_clients.pop(conn.client_ip, None)
|
connected_clients.pop(conn.client_ip, None)
|
||||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:心跳确认发送失败 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 心跳确认发送失败 - {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -213,17 +213,17 @@ async def handle_text_msg(conn: ClientConnection, text: str):
|
|||||||
conn.update_heartbeat()
|
conn.update_heartbeat()
|
||||||
await send_heartbeat_ack(conn)
|
await send_heartbeat_ack(conn)
|
||||||
else:
|
else:
|
||||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:未知文本消息类型({msg.get('type')})")
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 未知文本消息类型({msg.get('type')})")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:无效JSON文本消息")
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 无效JSON文本消息")
|
||||||
|
|
||||||
|
|
||||||
async def handle_binary_msg(conn: ClientConnection, data: bytes):
|
async def handle_binary_msg(conn: ClientConnection, data: bytes):
|
||||||
try:
|
try:
|
||||||
conn.frame_queue.put_nowait(data)
|
conn.frame_queue.put_nowait(data)
|
||||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:图像数据({len(data)}字节)已加入队列")
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 图像数据({len(data)}字节)已加入队列")
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}:帧队列已满,丢弃当前图像数据")
|
print(f"[{get_current_time_str()}] 客户端{conn.client_ip}: 帧队列已满、丢弃当前图像数据")
|
||||||
|
|
||||||
|
|
||||||
# WebSocket路由配置
|
# WebSocket路由配置
|
||||||
@ -237,7 +237,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
client_ip = websocket.client.host if websocket.client else "unknown_ip"
|
||||||
current_time = get_current_time_str()
|
current_time = get_current_time_str()
|
||||||
print(f"[{current_time}] 客户端{client_ip}:WebSocket连接已建立")
|
print(f"[{current_time}] 客户端{client_ip}: WebSocket连接已建立")
|
||||||
|
|
||||||
is_online_updated = False
|
is_online_updated = False
|
||||||
|
|
||||||
@ -249,13 +249,13 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
old_conn.consumer_task.cancel()
|
old_conn.consumer_task.cancel()
|
||||||
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
|
await old_conn.websocket.close(code=1008, reason="同一IP新连接建立")
|
||||||
connected_clients.pop(client_ip)
|
connected_clients.pop(client_ip)
|
||||||
print(f"[{current_time}] 客户端{client_ip}:已关闭旧连接")
|
print(f"[{current_time}] 客户端{client_ip}: 已关闭旧连接")
|
||||||
|
|
||||||
# 注册新连接
|
# 注册新连接
|
||||||
new_conn = ClientConnection(websocket, client_ip)
|
new_conn = ClientConnection(websocket, client_ip)
|
||||||
connected_clients[client_ip] = new_conn
|
connected_clients[client_ip] = new_conn
|
||||||
new_conn.start_consumer()
|
new_conn.start_consumer()
|
||||||
# 初始许可:连接建立后立即发一次,让客户端知道可发第一帧(后续靠取帧后自动发)
|
# 初始许可: 连接建立后立即发一次、让客户端知道可发第一帧(后续靠取帧后自动发)
|
||||||
await new_conn.send_frame_permit()
|
await new_conn.send_frame_permit()
|
||||||
|
|
||||||
# 标记上线并记录
|
# 标记上线并记录
|
||||||
@ -263,12 +263,12 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
|
await asyncio.to_thread(update_online_status_by_ip, client_ip, 1)
|
||||||
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
|
action_data = DeviceActionCreate(client_ip=client_ip, action=1)
|
||||||
await asyncio.to_thread(add_device_action, action_data)
|
await asyncio.to_thread(add_device_action, action_data)
|
||||||
print(f"[{current_time}] 客户端{client_ip}:已标记为在线并记录操作")
|
print(f"[{current_time}] 客户端{client_ip}: 已标记为在线并记录操作")
|
||||||
is_online_updated = True
|
is_online_updated = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{current_time}] 客户端{client_ip}:上线状态更新失败 - {str(e)}")
|
print(f"[{current_time}] 客户端{client_ip}: 上线状态更新失败 - {str(e)}")
|
||||||
|
|
||||||
print(f"[{current_time}] 客户端{client_ip}:新连接注册成功,在线数:{len(connected_clients)}")
|
print(f"[{current_time}] 客户端{client_ip}: 新连接注册成功、在线数: {len(connected_clients)}")
|
||||||
|
|
||||||
# 消息循环
|
# 消息循环
|
||||||
while True:
|
while True:
|
||||||
@ -279,9 +279,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
await handle_binary_msg(new_conn, data["bytes"])
|
await handle_binary_msg(new_conn, data["bytes"])
|
||||||
|
|
||||||
except WebSocketDisconnect as e:
|
except WebSocketDisconnect as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:主动断开连接(代码:{e.code})")
|
print(f"[{get_current_time_str()}] 客户端{client_ip}: 主动断开连接(代码: {e.code})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:连接异常 - {str(e)[:50]}")
|
print(f"[{get_current_time_str()}] 客户端{client_ip}: 连接异常 - {str(e)[:50]}")
|
||||||
finally:
|
finally:
|
||||||
# 清理资源并标记离线
|
# 清理资源并标记离线
|
||||||
if client_ip in connected_clients:
|
if client_ip in connected_clients:
|
||||||
@ -295,9 +295,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
|
await asyncio.to_thread(update_online_status_by_ip, client_ip, 0)
|
||||||
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
|
action_data = DeviceActionCreate(client_ip=client_ip, action=0)
|
||||||
await asyncio.to_thread(add_device_action, action_data)
|
await asyncio.to_thread(add_device_action, action_data)
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后已标记为离线")
|
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后已标记为离线")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:断开后离线更新失败 - {str(e)}")
|
print(f"[{get_current_time_str()}] 客户端{client_ip}: 断开后离线更新失败 - {str(e)}")
|
||||||
|
|
||||||
connected_clients.pop(client_ip, None)
|
connected_clients.pop(client_ip, None)
|
||||||
print(f"[{get_current_time_str()}] 客户端{client_ip}:资源已清理,在线数:{len(connected_clients)}")
|
print(f"[{get_current_time_str()}] 客户端{client_ip}: 资源已清理、在线数: {len(connected_clients)}")
|
||||||
|
Reference in New Issue
Block a user