目前可以成功动态更换模型运行的

This commit is contained in:
2025-09-12 14:05:09 +08:00
parent 435b2a0e6c
commit 4be7f7bf14
13 changed files with 1518 additions and 325 deletions

283
app.py Normal file
View File

@ -0,0 +1,283 @@
from flask import Flask, send_from_directory, abort, request
import os
import logging
from functools import wraps
from pathlib import Path
# 跨域依赖必须安装pip install flask-cors
from flask_cors import CORS
# 配置日志(保持原有格式)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 初始化 Flask 应用(供 main.py 导入)
app = Flask(__name__)
# ------------------------------
# 核心修改:与 FastAPI 对齐的跨域配置
# ------------------------------
# 1. 允许的前端域名(完全复制 FastAPI 的 ALLOWED_ORIGINS确保前后端一致
ALLOWED_ORIGINS = [
# "http://localhost:8080", # 本地前端开发地址(必改:替换为你的前端实际地址)
# "http://127.0.0.1:8080",
# "http://服务器IP:8080", # 部署后前端地址替换为你的服务器IP/域名)
# # "*" 仅开发环境临时使用,生产环境必须删除(安全风险)
"*"
]
# 2. 配置 CORS与 FastAPI 规则完全对齐)
CORS(
app,
resources={
r"/*": { # 对所有 Flask 路由生效(覆盖图片、模型下载所有接口)
"origins": ALLOWED_ORIGINS, # 允许的前端域名(与 FastAPI 一致)
"allow_credentials": True, # 允许携带 Cookie与 FastAPI 一致,需登录态必开)
"methods": ["*"], # 允许所有 HTTP 方法FastAPI 用 "*",此处同步)
"allow_headers": ["*"], # 允许所有请求头(与 FastAPI 一致)
}
},
)
# ------------------------------
# 核心路径配置(不变,确保资源目录正确)
# ------------------------------
CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parent # 项目根目录video/
# 资源目录(图片、模型)
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
# 打印路径配置(调试用,确认目录正确)
logger.info(f"[Flask 配置] 项目根目录:{PROJECT_ROOT}")
logger.info(f"[Flask 配置] 模型目录:{BASE_MODEL_DIR}")
logger.info(f"[Flask 配置] 人脸图片目录:{BASE_IMAGE_DIR_UP_IMAGES}")
logger.info(f"[Flask 配置] 检测图片目录:{BASE_IMAGE_DIR_DECT}")
# ------------------------------
# 安全检查装饰器(不变,防路径遍历/非法文件)
# ------------------------------
def safe_path_check(root_dir: str):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
resource_path = kwargs.get('resource_path', '').strip()
# 统一路径分隔符(兼容 Windows \ 和 Linux /
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
# 拼接完整路径(防止路径遍历)
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
logger.debug(
f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}"
)
# 1. 禁止路径遍历(确保请求文件在根目录内)
if not full_file_path.startswith(root_dir):
logger.warning(
f"[Flask 安全拦截] 非法路径遍历IP{request.remote_addr} | 请求路径:{resource_path}"
)
abort(403)
# 2. 检查文件存在且为有效文件(非目录)
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
logger.warning(
f"[Flask 资源错误] 文件不存在/非文件IP{request.remote_addr} | 路径:{full_file_path}"
)
abort(404)
# 3. 限制文件大小模型200MB图片10MB避免超大文件攻击
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
if os.path.getsize(full_file_path) > max_size:
logger.warning(
f"[Flask 大小超限] 文件超过{max_size//1024//1024}MBIP{request.remote_addr} | 路径:{full_file_path}"
)
abort(413)
# 安全检查通过,传递根目录给视图函数
return func(*args, **kwargs, root_dir=root_dir)
return wrapper
return decorator
# ------------------------------
# 1. 模型下载接口(/model/download/*
# ------------------------------
@app.route('/model/download/<path:resource_path>')
@safe_path_check(root_dir=BASE_MODEL_DIR)
def download_model(resource_path, root_dir):
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 仅允许 .pt 格式YOLO 模型)
if not file_name.lower().endswith('.pt'):
logger.warning(
f"[Flask 格式错误] 非 .pt 模型文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 模型下载] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
# 强制浏览器下载(而非预览),设置二进制文件类型
return send_from_directory(
full_dir,
file_name,
as_attachment=True,
mimetype="application/octet-stream"
)
except Exception as e:
logger.error(
f"[Flask 模型下载异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 2. 人脸图片访问接口(/up_images/*
# ------------------------------
@app.route('/up_images/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES)
def get_face_image(resource_path, root_dir):
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 仅允许常见图片格式
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 人脸图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
# 允许浏览器预览图片(而非下载)
return send_from_directory(full_dir, file_name, as_attachment=False)
except Exception as e:
logger.error(
f"[Flask 人脸图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 3. 检测图片访问接口(/resource/dect/*
# ------------------------------
@app.route('/resource/dect/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
def get_dect_image(resource_path, root_dir):
try:
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
# 仅允许常见图片格式
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 检测图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
return send_from_directory(full_dir, file_name, as_attachment=False)
except Exception as e:
logger.error(
f"[Flask 检测图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*
# ------------------------------
@app.route('/images/<path:resource_path>')
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
def get_compatible_image(resource_path, root_dir):
try:
# 逻辑与检测图片接口一致仅URL前缀不同兼容旧前端
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
dir_path, file_name = os.path.split(resource_path)
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
if not file_name.lower().endswith(allowed_ext):
logger.warning(
f"[Flask 格式错误] 非图片文件IP{request.remote_addr} | 文件名:{file_name}"
)
abort(415)
logger.info(
f"[Flask 兼容图片] 成功请求IP{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
)
return send_from_directory(full_dir, file_name, as_attachment=False)
except Exception as e:
logger.error(
f"[Flask 兼容图片异常] IP{request.remote_addr} | 错误:{str(e)}"
)
abort(500)
# ------------------------------
# 全局错误处理器(友好提示,与 FastAPI 错误信息风格一致)
# ------------------------------
@app.errorhandler(403)
def forbidden_error(error):
return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403
@app.errorhandler(404)
def not_found_error(error):
return "❌ 资源不存在请检查URL路径IP、目录、文件名是否正确", 404
@app.errorhandler(413)
def too_large_error(error):
return "❌ 文件过大图片最大10MB模型最大200MB", 413
@app.errorhandler(415)
def unsupported_type_error(error):
return "❌ 不支持的文件类型图片支持png/jpg/jpeg/gif/bmp模型仅支持pt", 415
@app.errorhandler(500)
def server_error(error):
return "❌ 服务器内部错误:请联系管理员查看后台日志", 500
# ------------------------------
# Flask 独立启动入口(供测试,实际由 main.py 子线程启动)
# ------------------------------
if __name__ == '__main__':
# 确保所有资源目录存在(防止初始化失败)
required_dirs = [
(BASE_IMAGE_DIR_DECT, "检测图片目录"),
(BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"),
(BASE_MODEL_DIR, "模型文件目录")
]
for dir_path, dir_desc in required_dirs:
if not os.path.exists(dir_path):
logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}")
os.makedirs(dir_path, exist_ok=True)
# 启动提示(含访问示例)
logger.info("\n[Flask 服务启动成功!] 支持的接口:")
logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt")
logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg")
logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n")
# 启动服务(禁用 debug 和自动重载,避免多线程冲突)
app.run(
host="0.0.0.0", # 允许外部IP访问
port=5000, # 与 main.py 中 Flask 端口一致
debug=False,
use_reloader=False
)

View File

@ -57,50 +57,39 @@ def save_db(model_type, client_ip, result):
# 修正后的 detect 函数关键部分
def detect(client_ip, frame): def detect(client_ip, frame):
""" # 1. YOLO检测
执行模型检测,检测到违规时按指定格式保存图片
参数:
frame: 待检测的图像帧OpenCV格式numpy.ndarray类型
返回:
(检测结果布尔值, 检测详情, 检测模型类型)
"""
# 1. YOLO检测优先级1
yolo_flag, yolo_result = yoloDetect(frame) yolo_flag, yolo_result = yoloDetect(frame)
print(f"YOLO检测结果{yolo_result}")
if yolo_flag: if yolo_flag:
# model_type 传入 "yolo"(正确)
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
if full_save_path: # 只判断完整路径是否有效(用于保存) if full_save_path:
cv2.imwrite(full_save_path, frame) cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式 print(f"✅ yolo违规图片已保存{display_path}") # 日志也修正
print(f"✅ YOLO违规图片已保存{display_path}")
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
return (True, yolo_result, "yolo") return (True, yolo_result, "yolo")
#
# # 2. 人脸检测优先级2 # 2. 人脸检测
face_flag, face_result = faceDetect(frame) face_flag, face_result = faceDetect(frame)
print(f"人脸检测结果:{face_result}")
if face_flag: if face_flag:
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip) # 这里改了
if full_save_path: # 只判断完整路径是否有效(用于保存) if full_save_path:
cv2.imwrite(full_save_path, frame) cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式 print(f"✅ face违规图片已保存{display_path}") # 日志也修正
print(f"✅ face违规图片已保存{display_path}")
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
return (True, face_result, "face") return (True, face_result, "face")
# 3. OCR检测优先级3 # 3. OCR检测
ocr_flag, ocr_result = ocrDetect(frame) ocr_flag, ocr_result = ocrDetect(frame)
print(f"OCR检测结果{ocr_result}")
if ocr_flag: if ocr_flag:
# 解构元组,保存用完整路径,打印用短路径 full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) if full_save_path:
if full_save_path: # 只判断完整路径是否有效(用于保存)
cv2.imwrite(full_save_path, frame) cv2.imwrite(full_save_path, frame)
# 打印时使用「显示用短路径」,符合需求格式 print(f"✅ ocr违规图片已保存{display_path}") # 日志也修正
print(f"✅ ocr违规图片已保存{display_path}")
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
return (True, ocr_result, "ocr") return (True, ocr_result, "ocr")
# 4. 无违规内容(不保存图片) # 4. 无违规内容(不保存图片)
print(f"❌ 未检测到任何违规内容,不保存图片") print(f"❌ 未检测到任何违规内容,不保存图片")
return (False, "未检测到任何内容", "none") return (False, "未检测到任何内容", "none")

View File

@ -1,37 +1,43 @@
import os import os
import numpy as np
from ultralytics import YOLO from ultralytics import YOLO
from service.model_service import get_current_yolo_model # 从模型管理模块获取模型
# 全局变量 # 全局模型变量
_yolo_model = None _yolo_model = None
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt") def load_model(model_path=None):
"""加载YOLO模型优先使用模型管理模块的默认模型"""
def load_model():
"""加载YOLO目标检测模型"""
global _yolo_model global _yolo_model
if model_path is None:
_yolo_model = get_current_yolo_model()
return _yolo_model is not None
try: try:
_yolo_model = YOLO(model_path) _yolo_model = YOLO(model_path)
return True
except Exception as e: except Exception as e:
print(f"YOLO model load failed: {e}") print(f"YOLO模型加载失败(指定路径):{str(e)}")
return False return False
return True if _yolo_model else False
def detect(frame, conf_threshold=0.2): def detect(frame, conf_threshold=0.2):
"""YOLO目标检测返回(是否识别到, 结果字符串)""" """执行目标检测返回(是否成功, 结果字符串)"""
global _yolo_model global _yolo_model
if not _yolo_model or frame is None: # 确保模型已加载
return (False, "未初始化或无效帧") if not _yolo_model:
if not load_model():
return (False, "模型未初始化")
if frame is None:
return (False, "无效输入帧")
try: try:
results = _yolo_model(frame, conf=conf_threshold) # 执行检测frame应为numpy数组
# 检查是否有检测结果 results = _yolo_model(frame, conf=conf_threshold, verbose=False)
has_results = len(results[0].boxes) > 0 if results else False has_results = len(results[0].boxes) > 0 if results else False
if not has_results: if not has_results:
@ -42,13 +48,12 @@ def detect(frame, conf_threshold=0.2):
for box in results[0].boxes: for box in results[0].boxes:
cls = int(box.cls[0]) cls = int(box.cls[0])
conf = float(box.conf[0]) conf = float(box.conf[0])
bbox = [float(x) for x in box.xyxy[0]] bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}" class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})") result_parts.append(f"{class_name}置信度{conf:.2f},位置:{bbox}")
result_str = "; ".join(result_parts) return (True, "; ".join(result_parts))
return (has_results, result_str)
except Exception as e: except Exception as e:
print(f"YOLO detect error: {e}") print(f"检测过程出错:{str(e)}")
return (False, f"检测错误: {str(e)}") return (False, f"检测错误{str(e)}")

128
main.py
View File

@ -1,9 +1,17 @@
from PIL import Image # 正确导入
import numpy as np
import uvicorn
from PIL import Image from PIL import Image
import numpy as np
import uvicorn
import threading
import time
import os
from fastapi import FastAPI from fastapi import FastAPI
# 新增:导入 CORS 相关依赖
from fastapi.middleware.cors import CORSMiddleware
# 导入 Flask 服务实例
from app import app as flask_app
# 原有业务导入
from core.all import load_model, detect from core.all import load_model, detect
from ds.config import SERVER_CONFIG from ds.config import SERVER_CONFIG
from middle.error_handler import global_exception_handler from middle.error_handler import global_exception_handler
@ -11,52 +19,124 @@ from service.user_service import router as user_router
from service.sensitive_service import router as sensitive_router from service.sensitive_service import router as sensitive_router
from service.face_service import router as face_router from service.face_service import router as face_router
from service.device_service import router as device_router from service.device_service import router as device_router
from service.model_service import router as model_router # 模型管理路由
from ws.ws import ws_router, lifespan from ws.ws import ws_router, lifespan
from core.establish import create_directory_structure from core.establish import create_directory_structure
# ------------------------------
# 初始化 FastAPI 应用、指定生命周期管理 # Flask 服务启动函数(不变)
# ------------------------------ def start_flask_service():
try:
print(f"\n[Flask 服务] 准备启动端口5000")
print(f"[Flask 服务] 访问示例http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n")
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
if not os.path.exists(BASE_IMAGE_DIR):
print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
flask_app.run(
host="0.0.0.0",
port=5000,
debug=False,
use_reloader=False
)
except Exception as e:
print(f"[Flask 服务] 启动失败:{str(e)}")
# 初始化 FastAPI 应用(新增 CORS 配置)
app = FastAPI( app = FastAPI(
title="内容安全审核后台", title="内容安全审核后台",
description="内容安全审核后台", description="含图片访问服务和动态模型管理",
version="1.0.0", version="1.0.0",
lifespan=lifespan lifespan=lifespan
) )
# ------------------------------ # ------------------------------
# 注册路由 # 新增:完整 CORS 配置(解决跨域问题)
# ------------------------------ # ------------------------------
# 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等)
ALLOWED_ORIGINS = [
# "http://localhost:8080", # 前端本地开发地址(必改,填实际前端地址)
# "http://127.0.0.1:8080",
# "http://服务器IP:8080", # 部署后前端地址(如适用)
"*" #表示允许所有域名(开发环境可用,生产环境不推荐)
]
# 2. 配置 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS, # 允许的前端域名
allow_credentials=True, # 允许携带 Cookie如需登录态则必开
allow_methods=["*"], # 允许所有 HTTP 方法(包括 PUT/DELETE
allow_headers=["*"], # 允许所有请求头(包括 Content-Type
)
# 注册路由(不变)
app.include_router(user_router) app.include_router(user_router)
app.include_router(device_router) app.include_router(device_router)
app.include_router(face_router) app.include_router(face_router)
app.include_router(sensitive_router) app.include_router(sensitive_router)
app.include_router(model_router) # 模型管理路由
app.include_router(ws_router) app.include_router(ws_router)
# ------------------------------ # 注册全局异常处理器(不变)
# 注册全局异常处理器
# ------------------------------
app.add_exception_handler(Exception, global_exception_handler) app.add_exception_handler(Exception, global_exception_handler)
# ------------------------------ # 主服务启动入口(不变)
# 启动服务
# ------------------------------
if __name__ == "__main__": if __name__ == "__main__":
# -------------------------- 配置调整 -------------------------- # 1. 初始化资源
# 模型配置路径(建议改为环境变量)
YOLO_MODEL_PATH = r"/core/models\best.pt"
OCR_CONFIG_PATH = r"/core/config\config.yaml"
create_directory_structure() create_directory_structure()
print(f"[初始化] 目录结构创建完成")
# 初始化项目默认端口设为8000、避免初始化失败时port未定义 # 创建模型保存目录
MODEL_SAVE_DIR = os.path.join("core", "models")
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}")
# # 模型路径配置
# YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt")
# OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml")
# print(f"[初始化] 默认YOLO模型路径{YOLO_MODEL_PATH}")
# print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}")
# 加载检测模型
try:
load_success = load_model()
if load_success:
print(f"[初始化] 检测模型加载完成")
else:
print(f"[初始化] 未找到默认模型可通过API上传并设置默认模型")
except Exception as e:
print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)")
# 2. 启动 Flask 服务(子线程)
flask_thread = threading.Thread(
target=start_flask_service,
daemon=True
)
flask_thread.start()
# 等待 Flask 初始化
time.sleep(1)
if flask_thread.is_alive():
print(f"[Flask 服务] 启动成功(运行中)")
else:
print(f"[Flask 服务] 启动失败!图片访问不可用")
# 3. 启动 FastAPI 主服务
port = int(SERVER_CONFIG.get("port", 8000)) port = int(SERVER_CONFIG.get("port", 8000))
print(f"\n[FastAPI 服务] 准备启动,端口:{port}")
print(f"[FastAPI 服务] 接口文档http://服务器IP:{port}/docs\n")
# 启动 UVicorn 服务
uvicorn.run( uvicorn.run(
app="main:app", app="main:app",
host="0.0.0.0", host="0.0.0.0",
port=port, port=port,
workers=8, workers=1,
ws="websockets" ws="websockets",
reload=False
) )

View File

@ -1,30 +1,41 @@
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional
# ------------------------------ # ------------------------------
# 请求模型(前端传参校验) # 请求模型(前端传参校验)- 保留update的eigenvalue如需更新特征值
# ------------------------------ # ------------------------------
class FaceCreateRequest(BaseModel): class FaceCreateRequest(BaseModel):
"""创建人脸记录请求模型无需ID、由数据库自增""" """创建人脸记录请求模型无需ID、由数据库自增"""
name: str = Field(None, max_length=255, description="名称可选、最长255字符") name: Optional[str] = Field(None, max_length=255, description="名称可选、最长255字符")
class FaceUpdateRequest(BaseModel): class FaceUpdateRequest(BaseModel):
"""更新人脸记录请求模型(不变""" """更新人脸记录请求模型 - 保留eigenvalue如需更新特征值不影响返回"""
name: str = Field(None, max_length=255, description="名称") name: Optional[str] = Field(None, max_length=255, description="名称(可选)")
eigenvalue: str = Field(None, max_length=255, description="特征文件处理后可更新)") eigenvalue: Optional[str] = Field(None, description="特征值(可选,文件处理后可更新)") # 保留更新能力
address: Optional[str] = Field(None, description="图片完整路径(可选,更新图片时使用)")
# ------------------------------ # ------------------------------
# 响应模型(后端返回数据) # 响应模型(后端返回数据)- 核心修改删除eigenvalue字段
# ------------------------------ # ------------------------------
class FaceResponse(BaseModel): class FaceResponse(BaseModel):
"""人脸记录响应模型(仍包含ID、由数据库生成后返回""" """人脸记录响应模型(仅返回需要的字段移除eigenvalue"""
id: int = Field(..., description="主键ID数据库自增") id: int = Field(..., description="主键ID数据库自增")
name: str = Field(None, description="名称") name: Optional[str] = Field(None, description="名称")
eigenvalue: str | None = Field(None, description="特征(可为空)") address: Optional[str] = Field(None, description="人脸图片完整保存路径(数据库新增字段)") # 仅保留address
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间(数据库自动生成)")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间(数据库自动生成)")
# 关键配置:支持从数据库查询结果(字典)直接转换
model_config = {"from_attributes": True}
class FaceListResponse(BaseModel):
"""人脸列表分页响应模型结构不变内部FaceResponse已移除eigenvalue"""
total: int = Field(..., description="筛选后的总记录数")
faces: List[FaceResponse] = Field(..., description="当前页的人脸记录列表")
model_config = {"from_attributes": True} model_config = {"from_attributes": True}

37
schema/model_schema.py Normal file
View File

@ -0,0 +1,37 @@
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List, Optional
# 请求模型
class ModelCreateRequest(BaseModel):
name: str = Field(..., max_length=255, description="模型名称必填yolo-v8s-car")
description: Optional[str] = Field(None, description="模型描述(可选)")
is_default: Optional[bool] = Field(False, description="是否设为默认模型")
class ModelUpdateRequest(BaseModel):
name: Optional[str] = Field(None, max_length=255, description="模型名称(可选修改)")
description: Optional[str] = Field(None, description="模型描述(可选修改)")
is_default: Optional[bool] = Field(None, description="是否设为默认模型(可选切换)")
# 响应模型
class ModelResponse(BaseModel):
id: int = Field(..., description="模型ID")
name: str = Field(..., description="模型名称")
path: str = Field(..., description="模型文件相对路径")
is_default: bool = Field(..., description="是否默认模型")
description: Optional[str] = Field(None, description="模型描述")
file_size: Optional[int] = Field(None, description="文件大小(字节)")
created_at: datetime = Field(..., description="创建时间")
updated_at: datetime = Field(..., description="更新时间")
model_config = {"from_attributes": True}
class ModelListResponse(BaseModel):
total: int = Field(..., description="总记录数")
models: List[ModelResponse] = Field(..., description="当前页模型列表")
model_config = {"from_attributes": True}

View File

@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional
# ------------------------------ # ------------------------------
@ -30,3 +30,11 @@ class UserResponse(BaseModel):
# Pydantic V2 配置(支持从数据库查询结果转换) # Pydantic V2 配置(支持从数据库查询结果转换)
model_config = {"from_attributes": True} model_config = {"from_attributes": True}
class UserListResponse(BaseModel):
"""用户列表分页响应模型(与设备/人脸列表结构对齐)"""
total: int = Field(..., description="用户总数")
users: List[UserResponse] = Field(..., description="当前页用户列表")
model_config = {"from_attributes": True}

View File

@ -1,162 +1,140 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
import os
from pathlib import Path
from ds.db import db from ds.db import db
from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse from schema.face_schema import (
FaceCreateRequest,
FaceUpdateRequest,
FaceResponse,
FaceListResponse
)
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse
from util.face_util import add_binary_data, get_average_feature from util.face_util import add_binary_data, get_average_feature
#初始化实例 from util.file_util import save_face_to_up_images
router = APIRouter(
prefix="/faces",
tags=["人脸管理"]
)
router = APIRouter(prefix="/faces", tags=["人脸管理"])
# ------------------------------ # ------------------------------
# 1. 创建人脸记录(核心修正: ID 数据库自增、前端无需传 # 1. 创建人脸记录(使用修复后的路径
# ------------------------------ # ------------------------------
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件、ID自增") @router.post("", response_model=APIResponse, summary="创建人脸记录")
async def create_face( async def create_face(
# 前端仅需传: name可选、Form格式、file必传、文件 request: Request,
name: str = Form(None, max_length=255, description="名称(可选)"), name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传、暂不处理内容") file: UploadFile = File(..., description="人脸文件(必传)")
): ):
"""
创建人脸记录:
- 需登录认证
- 前端传参: multipart/form-data 表单name 可选、file 必传)
- ID 由数据库自动生成、无需前端传入
- 暂不处理文件内容、eigenvalue 设为 None
"""
# 调用你的方法
conn = None conn = None
cursor = None cursor = None
try: try:
# 1. 用模型校验 name仅校验长度、无需ID
face_create = FaceCreateRequest(name=name) face_create = FaceCreateRequest(name=name)
client_ip = request.client.host if request.client else ""
if not client_ip:
raise HTTPException(status_code=400, detail="无法获取客户端IP")
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 把文件转为二进制数组 # 读取图片并保存(使用修复后的路径逻辑)
file_content = await file.read() file_content = await file.read()
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "jpg"
# 计算特征值 save_result = save_face_to_up_images(
flag, eigenvalue = add_binary_data(file_content) client_ip=client_ip,
face_name=name,
if flag == False: image_bytes=file_content,
raise HTTPException( image_format=file_ext
status_code=500,
detail="未检测到人脸"
) )
if not save_result["success"]:
raise HTTPException(status_code=500, detail=f"图片保存失败:{save_result['msg']}")
db_image_path = save_result["db_path"] # 从修复后的方法获取路径
# 打印数组长度 # 提取人脸特征
print(f"文件大小: {len(file_content)} 字节") detect_success, detect_result = add_binary_data(file_content)
if not detect_success:
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
eigenvalue = detect_result
# 2. 插入数据库: 无需传 ID自增、只传 name 和 eigenvalueNone # 插入数据库
insert_query = """ insert_query = """
INSERT INTO face (name, eigenvalue) INSERT INTO face (name, eigenvalue, address)
VALUES (%s, %s) VALUES (%s, %s, %s)
""" """
cursor.execute(insert_query, (face_create.name, str(eigenvalue))) cursor.execute(insert_query, (face_create.name, str(eigenvalue), db_image_path))
conn.commit() conn.commit()
# 3. 获取数据库自动生成的 ID关键: 用 LAST_INSERT_ID() 查刚插入的记录 # 查询新记录
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()" cursor.execute("""
cursor.execute(select_new_query) SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = LAST_INSERT_ID()
""")
created_face = cursor.fetchone() created_face = cursor.fetchone()
if not created_face: if not created_face:
raise HTTPException( raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
status_code=500,
detail="创建人脸记录成功、但无法获取新创建的记录"
)
return APIResponse( return APIResponse(
code=201, code=201,
message=f"人脸记录创建成功ID: {created_face['id']}、文件名: {file.filename}", message=f"人脸记录创建成功ID: {created_face['id']}",
data=FaceResponse(**created_face) data=FaceResponse(**created_face)
) )
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
# 改为使用HTTPException raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
raise HTTPException(
status_code=500,
detail=f"创建人脸记录失败: {str(e)}"
) from e
except Exception as e: except Exception as e:
# 捕获其他可能的异常 raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
raise HTTPException(
status_code=500,
detail=f"服务器错误: {str(e)}"
) from e
finally: finally:
await file.close() # 关闭文件流 await file.close()
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 将 eigenvalue 转为 str
eigenvalue = str(eigenvalue)
# 其他接口(获取单条/列表、更新、删除、获取图片)与之前一致,无需修改
# ------------------------------ # ------------------------------
# 2. 获取单个人脸记录不变、用自增ID查询 # 2. 获取单个人脸记录
# ------------------------------ # ------------------------------
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录") @router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
async def get_face( async def get_face(face_id: int):
face_id: int, # 这里的 ID 是数据库自增的、前端从创建响应中获取
current_user: UserResponse = Depends(get_current_user)
):
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM face WHERE id = %s" query = """
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
"""
cursor.execute(query, (face_id,)) cursor.execute(query, (face_id,))
face = cursor.fetchone() face = cursor.fetchone()
if not face: if not face:
raise HTTPException( raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
return APIResponse( return APIResponse(
code=200, code=200,
message="人脸记录查询成功", message="查询成功",
data=FaceResponse(**face) data=FaceResponse(**face)
) )
except MySQLError as e: except MySQLError as e:
# 改为使用HTTPException raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
raise HTTPException(
status_code=500,
detail=f"查询人脸记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理)
# ------------------------------ # ------------------------------
# 3. 获取所有人脸记录(不变) # 3. 获取人脸列表
# ------------------------------ # ------------------------------
@router.get("", response_model=APIResponse, summary="获取所有人脸记录") @router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
async def get_all_faces( async def get_face_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
has_eigenvalue: bool = Query(None)
): ):
conn = None conn = None
cursor = None cursor = None
@ -164,50 +142,66 @@ async def get_all_faces(
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM face ORDER BY id" # 按自增ID排序 where_clause = []
cursor.execute(query) params = []
faces = cursor.fetchall() if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if has_eigenvalue is not None:
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
# 总记录数
count_query = "SELECT COUNT(*) AS total FROM face"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 列表数据
offset = (page - 1) * page_size
list_query = """
SELECT id, name, address, created_at, updated_at
FROM face
"""
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_query, params)
face_list = cursor.fetchall()
return APIResponse( return APIResponse(
code=200, code=200,
message="所有人脸记录查询成功", message=f"获取成功(共{total}条)",
data=[FaceResponse(** face) for face in faces] data=FaceListResponse(
total=total,
faces=[FaceResponse(**face) for face in face_list]
)
) )
except MySQLError as e: except MySQLError as e:
raise HTTPException( raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
status_code=500,
detail=f"查询所有人脸记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 4. 更新人脸记录不变、用自增ID更新 # 4. 更新人脸记录
# ------------------------------ # ------------------------------
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录") @router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
async def update_face( async def update_face(face_id: int, face_update: FaceUpdateRequest):
face_id: int,
face_update: FaceUpdateRequest,
current_user: UserResponse = Depends(get_current_user)
):
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 检查记录是否存在 cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
check_query = "SELECT id FROM face WHERE id = %s" exist_face = cursor.fetchone()
cursor.execute(check_query, (face_id,)) if not exist_face:
existing_face = cursor.fetchone() raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
if not existing_face: old_db_path = exist_face["address"]
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
# 构建更新语句
update_fields = [] update_fields = []
params = [] params = []
if face_update.name is not None: if face_update.name is not None:
@ -216,6 +210,18 @@ async def update_face(
if face_update.eigenvalue is not None: if face_update.eigenvalue is not None:
update_fields.append("eigenvalue = %s") update_fields.append("eigenvalue = %s")
params.append(face_update.eigenvalue) params.append(face_update.eigenvalue)
if face_update.address is not None:
# 删除旧图片(相对路径转绝对路径)
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink() # 使用Path方法删除更安全
print(f"[FaceRouter] 已删除旧图片:{old_abs_path}")
except Exception as e:
print(f"[FaceRouter] 删除旧图片失败:{str(e)}")
update_fields.append("address = %s")
params.append(face_update.address)
if not update_fields: if not update_fields:
raise HTTPException(status_code=400, detail="至少需提供一个更新字段") raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
@ -225,117 +231,143 @@ async def update_face(
cursor.execute(update_query, params) cursor.execute(update_query, params)
conn.commit() conn.commit()
# 查询更新后记录 cursor.execute("""
select_query = "SELECT * FROM face WHERE id = %s" SELECT id, name, address, created_at, updated_at
cursor.execute(select_query, (face_id,)) FROM face
WHERE id = %s
""", (face_id,))
updated_face = cursor.fetchone() updated_face = cursor.fetchone()
return APIResponse( return APIResponse(
code=200, code=200,
message="人脸记录更新成功", message="更新成功",
data=FaceResponse(**updated_face) data=FaceResponse(**updated_face)
) )
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise HTTPException( raise HTTPException(status_code=500, detail=f"更新失败: {str(e)}") from e
status_code=500,
detail=f"更新人脸记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 5. 删除人脸记录不变、用自增ID删除 # 5. 删除人脸记录
# ------------------------------ # ------------------------------
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录") @router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
async def delete_face( async def delete_face(face_id: int):
face_id: int,
current_user: UserResponse = Depends(get_current_user)
):
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
check_query = "SELECT id FROM face WHERE id = %s" cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
cursor.execute(check_query, (face_id,)) exist_face = cursor.fetchone()
existing_face = cursor.fetchone() if not exist_face:
if not existing_face: raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
raise HTTPException( old_db_path = exist_face["address"]
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
delete_query = "DELETE FROM face WHERE id = %s" cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
cursor.execute(delete_query, (face_id,))
conn.commit() conn.commit()
# 删除图片
extra_msg = ""
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink()
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
extra_msg = "(已同步删除图片)"
except Exception as e:
print(f"[FaceRouter] 删除图片失败:{str(e)}")
extra_msg = "(图片删除失败)"
else:
extra_msg = "(图片不存在)"
else:
extra_msg = "(无关联图片)"
return APIResponse( return APIResponse(
code=200, code=200,
message=f"ID为 {face_id}人脸记录删除成功", message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
data=None data=None
) )
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise HTTPException( raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
status_code=500,
detail=f"删除人脸记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
def get_all_face_name_with_eigenvalue() -> dict: # ------------------------------
""" # 6. 获取人脸图片
获取所有人脸的名称及其对应的特征值、组成字典返回 # ------------------------------
key: 人脸名称name @router.get("/{face_id}/image", summary="获取人脸图片")
value: 人脸特征值eigenvalue、若名称重复则返回平均特征值 async def get_face_image(face_id: int):
注: 过滤掉name为None的记录、避免字典key为None的情况 conn = None
""" cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT address, name FROM face WHERE id = %s"
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
db_path = face["address"]
abs_path = Path(db_path).resolve() # 转为绝对路径
if not db_path or not abs_path.exists():
raise HTTPException(status_code=404, detail=f"图片不存在(路径:{db_path}")
return FileResponse(
path=abs_path,
filename=f"face_{face_id}_{face['name'] or '未命名'}.{db_path.split('.')[-1]}",
media_type=f"image/{db_path.split('.')[-1]}"
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"获取图片失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法
# ------------------------------
def get_all_face_name_with_eigenvalue() -> dict:
conn = None conn = None
cursor = None cursor = None
try: try:
# 1. 建立数据库连接并获取游标dictionary=True使结果以字典形式返回
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 2. 执行SQL查询: 只获取name非空的记录、减少数据传输
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query) cursor.execute(query)
faces = cursor.fetchall() # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...] faces = cursor.fetchall()
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
name_to_eigenvalues = {} name_to_eigenvalues = {}
for face in faces: for face in faces:
name = face["name"] name = face["name"]
eigenvalue = face["eigenvalue"] eigenvalue = face["eigenvalue"]
# 若名称已存在、追加特征值;否则新建列表存储
if name in name_to_eigenvalues: if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue) name_to_eigenvalues[name].append(eigenvalue)
else: else:
name_to_eigenvalues[name] = [eigenvalue] name_to_eigenvalues[name] = [eigenvalue]
# 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值
face_dict = {} face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items(): for name, eigenvalues in name_to_eigenvalues.items():
# 处理特征值: 多个则求平均、单个则直接使用
if len(eigenvalues) > 1: if len(eigenvalues) > 1:
# 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入
face_dict[name] = get_average_feature(eigenvalues) face_dict[name] = get_average_feature(eigenvalues)
else: else:
# 取列表中唯一的特征值避免value为列表类型
face_dict[name] = eigenvalues[0] face_dict[name] = eigenvalues[0]
return face_dict return face_dict
except MySQLError as e: except MySQLError as e:
# 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题) raise Exception(f"获取人脸特征失败: {str(e)}") from e
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
finally: finally:
# 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏)
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

497
service/model_service.py Normal file
View File

@ -0,0 +1,497 @@
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
import os
import shutil
from pathlib import Path
from datetime import datetime
# 复用项目依赖
from ds.db import db
from schema.model_schema import (
ModelCreateRequest,
ModelUpdateRequest,
ModelResponse,
ModelListResponse
)
from schema.response_schema import APIResponse
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
# 路径配置
CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent
MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models"
MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True)
DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
# 模型限制
ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
# 全局模型变量
global _yolo_model
_yolo_model = None
router = APIRouter(prefix="/models", tags=["模型管理"])
# 工具函数:验证模型路径
def get_valid_model_abs_path(relative_path: str) -> str:
try:
relative_path = relative_path.replace("/", os.sep)
model_abs_path = PROJECT_ROOT / relative_path
model_abs_path = model_abs_path.resolve()
model_abs_path_str = str(model_abs_path)
if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)):
raise HTTPException(
status_code=400,
detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
)
if not model_abs_path.exists():
raise HTTPException(
status_code=404,
detail=f"模型文件不存在!路径:{model_abs_path_str}"
)
if not model_abs_path.is_file():
raise HTTPException(
status_code=400,
detail=f"路径不是文件!路径:{model_abs_path_str}"
)
file_size = model_abs_path.stat().st_size
if file_size > MAX_MODEL_SIZE:
raise HTTPException(
status_code=400,
detail=f"模型文件过大({file_size // 1024 // 1024}MB超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
)
file_ext = model_abs_path.suffix.lower()
if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]:
raise HTTPException(
status_code=400,
detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
)
print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
return model_abs_path_str
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"路径处理失败:{str(e)}"
) from e
# 1. 上传模型
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
async def upload_model(
name: str = Form(..., description="模型名称"),
description: str = Form(None, description="模型描述"),
is_default: bool = Form(False, description="是否设为默认模型"),
file: UploadFile = File(..., description=f"YOLO模型文件.pt最大{MAX_MODEL_SIZE // 1024 // 1024}MB")
):
conn = None
cursor = None
saved_file_path = None
try:
# 校验文件
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
if file_ext not in ALLOWED_MODEL_EXT:
raise HTTPException(
status_code=400,
detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
)
if file.size > MAX_MODEL_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB当前{file.size // 1024 // 1024}MB"
)
# 保存文件
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}"
saved_file_path = MODEL_SAVE_ROOT / safe_filename
with open(saved_file_path, "wb") as f:
shutil.copyfileobj(file.file, f)
saved_file_path.chmod(0o644) # 设置权限
# 数据库路径处理
db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/")
# 数据库操作
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
if is_default:
cursor.execute("UPDATE model SET is_default = 0")
insert_sql = """
INSERT INTO model (name, path, is_default, description, file_size)
VALUES (%s, %s, %s, %s, %s)
"""
cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size))
conn.commit()
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
new_model = cursor.fetchone()
if not new_model:
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
# 加载默认模型
global _yolo_model
if is_default:
valid_abs_path = get_valid_model_abs_path(db_relative_path)
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
raise HTTPException(
status_code=500,
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path}"
)
return APIResponse(
code=201,
message=f"模型上传成功ID{new_model['id']}",
data=ModelResponse(**new_model)
)
except MySQLError as e:
if conn:
conn.rollback()
if saved_file_path and saved_file_path.exists():
saved_file_path.unlink()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
except Exception as e:
if saved_file_path and saved_file_path.exists():
saved_file_path.unlink()
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
finally:
await file.close()
db.close_connection(conn, cursor)
# 2. 获取模型列表
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
async def get_model_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
is_default: bool = Query(None)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if is_default is not None:
where_clause.append("is_default = %s")
params.append(1 if is_default else 0)
# 总记录数
count_sql = "SELECT COUNT(*) AS total FROM model"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 分页数据
offset = (page - 1) * page_size
list_sql = "SELECT * FROM model"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_sql, params)
model_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取成功!共{total}条记录",
data=ModelListResponse(
total=total,
models=[ModelResponse(**model) for model in model_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 3. 获取默认模型
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
async def get_default_model():
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
raise HTTPException(status_code=404, detail="暂无默认模型")
valid_abs_path = get_valid_model_abs_path(default_model["path"])
global _yolo_model
if not _yolo_model:
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
raise HTTPException(
status_code=500,
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path}"
)
return APIResponse(
code=200,
message="默认模型查询成功",
data=ModelResponse(**default_model)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 4. 获取单个模型详情
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
async def get_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
model = cursor.fetchone()
if not model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
try:
model_abs_path = get_valid_model_abs_path(model["path"])
except HTTPException as e:
return APIResponse(
code=200,
message=f"查询成功,但路径异常:{e.detail}",
data=ModelResponse(**model)
)
return APIResponse(
code=200,
message="查询成功",
data=ModelResponse(**model)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 5. 更新模型信息
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
async def update_model(model_id: int, model_update: ModelUpdateRequest):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
exist_model = cursor.fetchone()
if not exist_model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
update_fields = []
params = []
if model_update.name is not None:
update_fields.append("name = %s")
params.append(model_update.name)
if model_update.description is not None:
update_fields.append("description = %s")
params.append(model_update.description)
need_load_default = False
if model_update.is_default is not None:
if model_update.is_default:
cursor.execute("UPDATE model SET is_default = 0")
update_fields.append("is_default = 1")
need_load_default = True
else:
cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
default_count = cursor.fetchone()["cnt"]
if default_count == 1 and exist_model["is_default"]:
raise HTTPException(
status_code=400,
detail="当前是唯一默认模型,不可取消!"
)
update_fields.append("is_default = 0")
if not update_fields:
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
params.append(model_id)
update_sql = f"""
UPDATE model
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
WHERE id = %s
"""
cursor.execute(update_sql, params)
conn.commit()
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone()
global _yolo_model
if need_load_default:
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
raise HTTPException(
status_code=500,
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path}"
)
return APIResponse(
code=200,
message="模型更新成功",
data=ModelResponse(**updated_model)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 6. 删除模型
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
async def delete_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
exist_model = cursor.fetchone()
if not exist_model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
if exist_model["is_default"]:
raise HTTPException(status_code=400, detail="默认模型不可删除!")
try:
model_abs_path_str = get_valid_model_abs_path(exist_model["path"])
model_abs_path = Path(model_abs_path_str)
except HTTPException as e:
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
conn.commit()
return APIResponse(
code=200,
message=f"记录删除成功,文件异常:{e.detail}",
data=None
)
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
conn.commit()
extra_msg = ""
try:
model_abs_path.unlink()
extra_msg = f"(已删除文件)"
except Exception as e:
extra_msg = f"(文件删除失败:{str(e)}"
global _yolo_model
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str:
_yolo_model = None
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str}")
return APIResponse(
code=200,
message=f"模型删除成功ID{model_id} {extra_msg}",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 7. 下载模型文件
@router.get("/{model_id}/download", summary="下载模型文件")
async def download_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
model = cursor.fetchone()
if not model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
valid_abs_path = get_valid_model_abs_path(model["path"])
model_abs_path = Path(valid_abs_path)
return FileResponse(
path=model_abs_path,
filename=f"model_{model_id}_{model['name']}.pt",
media_type="application/octet-stream"
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 对外提供当前模型
def get_current_yolo_model():
"""供检测模块获取当前加载的模型"""
global _yolo_model
if not _yolo_model:
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT path FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
print("[get_current_yolo_model] 暂无默认模型")
return None
valid_abs_path = get_valid_model_abs_path(default_model["path"])
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
print(f"[get_current_yolo_model] 自动加载默认模型成功")
else:
print(f"[get_current_yolo_model] 自动加载默认模型失败")
except Exception as e:
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
finally:
db.close_connection(conn, cursor)
return _yolo_model

View File

@ -1,6 +1,7 @@
from datetime import timedelta from datetime import timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Query
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
from ds.db import db from ds.db import db
@ -11,7 +12,7 @@ from middle.auth_middleware import (
verify_password, verify_password,
create_access_token, create_access_token,
ACCESS_TOKEN_EXPIRE_MINUTES, ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_user get_current_user # 仅保留登录用户校验移除is_admin导入
) )
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类) # 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
@ -152,3 +153,98 @@ async def get_current_user_info(
data=current_user data=current_user
) )
# ------------------------------
# 4. 获取用户列表(仅需登录权限)
# ------------------------------
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
async def get_user_list(
page: int = Query(1, ge=1, description="页码从1开始"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
username: Optional[str] = Query(None, description="用户名模糊搜索"),
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
):
"""
获取用户列表:
- 需登录权限(请求头携带 Token: Bearer <token>
- 支持分页查询page=页码page_size=每页条数)
- 支持用户名模糊搜索(如输入"test"可匹配"test123""admin_test"等)
- 仅返回用户ID、用户名、创建时间、更新时间不包含密码等敏感信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 计算分页偏移量page从1开始偏移量=(页码-1*每页条数)
offset = (page - 1) * page_size
# 基础查询(仅查非敏感字段)
base_query = """
SELECT id, username, created_at, updated_at
FROM users
"""
# 总条数查询(用于分页计算)
count_query = "SELECT COUNT(*) as total FROM users"
# 条件拼接(支持用户名模糊搜索)
conditions = []
params = []
if username:
conditions.append("username LIKE %s")
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
# 构建最终查询语句
if conditions:
where_clause = " WHERE " + " AND ".join(conditions)
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
final_count_query = f"{count_query}{where_clause}"
params.extend([page_size, offset]) # 追加分页参数
else:
final_query = f"{base_query} LIMIT %s OFFSET %s"
final_count_query = count_query
params = [page_size, offset]
# 1. 查询用户列表数据
cursor.execute(final_query, params)
users = cursor.fetchall()
# 2. 查询总条数(用于计算总页数)
count_params = [f"%{username}%"] if username else []
cursor.execute(final_count_query, count_params)
total = cursor.fetchone()["total"]
# 3. 转换为UserResponse模型确保字段匹配
user_list = [
UserResponse(
id=user["id"],
username=user["username"],
created_at=user["created_at"],
updated_at=user["updated_at"]
)
for user in users
]
# 4. 计算总页数向上取整如11条数据每页10条=2页
total_pages = (total + page_size - 1) // page_size
# 返回结果(包含列表和分页信息)
return APIResponse(
code=200,
message="获取用户列表成功",
data={
"users": user_list,
"pagination": {
"page": page, # 当前页码
"page_size": page_size, # 每页条数
"total": total, # 总数据量
"total_pages": total_pages # 总页数
}
}
)
except MySQLError as e:
raise Exception(f"获取用户列表失败: {str(e)}") from e
finally:
# 无论成功失败,都关闭数据库连接
db.close_connection(conn, cursor)

View File

@ -4,6 +4,11 @@ import insightface
from insightface.app import FaceAnalysis from insightface.app import FaceAnalysis
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
import logging
# 配置日志(便于排查)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - [FaceUtil] - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# 全局变量存储InsightFace引擎和特征列表 # 全局变量存储InsightFace引擎和特征列表
_insightface_app = None _insightface_app = None
@ -11,135 +16,141 @@ _feature_list = []
def init_insightface(): def init_insightface():
"""初始化InsightFace引擎""" """初始化InsightFace引擎(确保成功后再使用)"""
global _insightface_app global _insightface_app
try: try:
print("正在初始化InsightFace引擎...") if _insightface_app is not None:
app = FaceAnalysis(name='buffalo_l', root='~/.insightface') logger.info("InsightFace引擎已初始化,无需重复执行")
app.prepare(ctx_id=0, det_size=(640, 640)) return _insightface_app
print("InsightFace引擎初始化完成")
logger.info("正在初始化InsightFace引擎模型buffalo_l...")
# 手动指定模型下载路径(避免权限问题,可选)
app = FaceAnalysis(
name='buffalo_l',
root='~/.insightface', # 模型默认下载路径
providers=['CPUExecutionProvider'] # 强制用CPU若有GPU可加'CUDAExecutionProvider'
)
app.prepare(ctx_id=0, det_size=(640, 640)) # det_size越大小人脸检测越准
logger.info("InsightFace引擎初始化完成")
_insightface_app = app _insightface_app = app
return app return app
except Exception as e: except Exception as e:
print(f"InsightFace初始化失败: {e}") logger.error(f"InsightFace初始化失败{str(e)}", exc_info=True) # 打印详细堆栈
_insightface_app = None
return None return None
def add_binary_data(binary_data): def add_binary_data(binary_data):
""" """
接收单张图片的二进制数据、提取特征并保存 接收单张图片的二进制数据、提取特征并保存
返回:(True, 特征值numpy数组) 或 (False, 错误信息字符串)
参数:
binary_data: 图片的二进制数据bytes类型
返回:
成功提取特征时返回 (True, 特征值numpy数组)
失败时返回 (False, None)
""" """
global _insightface_app, _feature_list global _insightface_app, _feature_list
# 1. 先检查引擎是否初始化成功
if not _insightface_app: if not _insightface_app:
print("引擎未初始化、无法处理") init_result = init_insightface() # 尝试重新初始化
return False, None if not init_result:
error_msg = "InsightFace引擎未初始化无法检测人脸"
logger.error(error_msg)
return False, error_msg
try: try:
# 直接处理二进制数据: 转换为图像格式 # 2. 验证二进制数据有效性
img = Image.open(BytesIO(binary_data)) if len(binary_data) < 1024: # 过滤过小的无效图片小于1KB
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) error_msg = f"图片过小({len(binary_data)}字节),可能不是有效图片"
logger.warning(error_msg)
return False, error_msg
# 提取特征 # 3. 二进制数据转CV2格式关键步骤避免通道错误
faces = _insightface_app.get(frame) try:
if faces: img = Image.open(BytesIO(binary_data)).convert("RGB") # 强制转RGB
# 获取当前提取的特征值 frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) # InsightFace需要BGR格式
current_feature = faces[0].embedding
# 添加到特征列表
_feature_list.append(current_feature)
print(f"已累计 {len(_feature_list)} 个特征")
# 返回成功标志和当前特征值
return True, current_feature
else:
print("二进制数据中未检测到人脸")
return False, None
except Exception as e: except Exception as e:
print(f"处理二进制数据出错: {e}") error_msg = f"图片格式转换失败:{str(e)}"
return False, None logger.error(error_msg, exc_info=True)
return False, error_msg
# 4. 检查图片尺寸(避免极端尺寸导致检测失败)
height, width = frame.shape[:2]
if height < 64 or width < 64: # 人脸检测最小建议尺寸
error_msg = f"图片尺寸过小({width}x{height}需至少64x64像素"
logger.warning(error_msg)
return False, error_msg
# 5. 调用InsightFace检测人脸
logger.info(f"开始检测人脸(图片尺寸:{width}x{height}格式BGR")
faces = _insightface_app.get(frame)
if not faces:
error_msg = "未检测到人脸(请确保图片包含清晰正面人脸,无遮挡、不模糊)"
logger.warning(error_msg)
return False, error_msg
# 6. 提取特征并保存
current_feature = faces[0].embedding
_feature_list.append(current_feature)
logger.info(f"人脸检测成功,提取特征值(维度:{current_feature.shape[0]}),累计特征数:{len(_feature_list)}")
return True, current_feature
except Exception as e:
error_msg = f"处理图片时发生异常:{str(e)}"
logger.error(error_msg, exc_info=True)
return False, error_msg
# 以下函数保持不变get_average_feature/clear_features/get_feature_list
def get_average_feature(features=None): def get_average_feature(features=None):
"""
计算多个特征向量的平均值
参数:
features: 可选、特征值列表。如果未提供、则使用全局存储的_feature_list
每个元素可以是字符串格式或numpy数组
返回:
单一平均特征向量的numpy数组、若无可计算数据则返回None
"""
global _feature_list global _feature_list
try:
# 如果未提供features参数、则使用全局特征列表
if features is None: if features is None:
features = _feature_list features = _feature_list
try:
# 验证输入是否为列表且不为空
if not isinstance(features, list) or len(features) == 0: if not isinstance(features, list) or len(features) == 0:
print("输入必须是包含至少一个特征值的列表") logger.warning("输入必须是包含至少一个特征值的列表")
return None return None
# 处理每个特征值
processed_features = [] processed_features = []
for i, embedding in enumerate(features): for i, embedding in enumerate(features):
try: try:
if isinstance(embedding, str): if isinstance(embedding, str):
# 处理包含括号和逗号的字符串格式
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip() embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
embedding_list = [float(num) for num in embedding_str.split() if num.strip()] embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
embedding_np = np.array(embedding_list, dtype=np.float32) embedding_np = np.array(embedding_list, dtype=np.float32)
else: else:
embedding_np = np.array(embedding, dtype=np.float32) embedding_np = np.array(embedding, dtype=np.float32)
# 验证特征值格式
if len(embedding_np.shape) == 1: if len(embedding_np.shape) == 1:
processed_features.append(embedding_np) processed_features.append(embedding_np)
print(f"已添加第 {i + 1} 个特征值用于计算平均值") logger.info(f"已添加第 {i + 1} 个特征值用于计算平均值")
else: else:
print(f"跳过第 {i + 1} 个特征值不是一维数组") logger.warning(f"跳过第 {i + 1} 个特征值不是一维数组")
except Exception as e: except Exception as e:
print(f"处理第 {i + 1} 个特征值时出错: {e}") logger.error(f"处理第 {i + 1} 个特征值时出错{str(e)}")
# 确保有有效的特征值
if not processed_features: if not processed_features:
print("没有有效的特征值用于计算平均值") logger.warning("没有有效的特征值用于计算平均值")
return None return None
# 检查所有特征向量维度是否相同
dims = {feat.shape[0] for feat in processed_features} dims = {feat.shape[0] for feat in processed_features}
if len(dims) > 1: if len(dims) > 1:
print(f"特征值维度不一致、无法计算平均值。检测到的维度: {dims}") logger.error(f"特征值维度不一致{dims},无法计算平均值")
return None return None
# 计算平均值
avg_feature = np.mean(processed_features, axis=0) avg_feature = np.mean(processed_features, axis=0)
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量维度: {avg_feature.shape[0]}") logger.info(f"计算成功:{len(processed_features)} 个特征值的平均向量维度{avg_feature.shape[0]}")
return avg_feature return avg_feature
except Exception as e: except Exception as e:
print(f"计算平均特征值出错: {e}") logger.error(f"计算平均特征值出错{str(e)}", exc_info=True)
return None return None
def clear_features(): def clear_features():
"""清空已存储的特征数据"""
global _feature_list global _feature_list
_feature_list = [] _feature_list = []
print("已清空所有特征数据") logger.info("已清空所有特征数据")
def get_feature_list(): def get_feature_list():
"""获取当前存储的特征列表"""
global _feature_list global _feature_list
return _feature_list.copy() # 返回副本防止外部直接修改 logger.info(f"当前特征列表长度:{len(_feature_list)}")
return _feature_list.copy()

83
util/file_util.py Normal file
View File

@ -0,0 +1,83 @@
import os
import datetime
from pathlib import Path
from typing import Dict
def save_face_to_up_images(
client_ip: str,
face_name: str,
image_bytes: bytes,
image_format: str = "jpg"
) -> Dict[str, str]:
"""
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
修复路径计算错误确保所有路径在up_images根目录下
参数:
client_ip: 客户端IP原始格式如192.168.1.101
face_name: 人脸名称(用户输入,可为空)
image_bytes: 人脸图片二进制数据
image_format: 图片格式默认jpg
返回:
字典success是否成功、db_path存数据库的相对路径、local_abs_path本地绝对路径、msg提示
"""
try:
# 1. 基础参数校验
if not client_ip.strip():
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"}
if not image_bytes:
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "图片二进制数据为空"}
if image_format.lower() not in ["jpg", "jpeg", "png"]:
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"}
# 2. 处理特殊字符(避免路径错误)
safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_
safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名"
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
root_dir = Path("up_images").resolve() # 转为绝对路径(关键修复!)
if not root_dir.exists():
root_dir.mkdir(parents=True, exist_ok=True)
print(f"[FileUtil] 已创建up_images根目录{root_dir}")
# 4. 构建文件层级路径确保在root_dir子目录下
ip_dir = root_dir / safe_ip
face_name_dir = ip_dir / safe_face_name
face_name_dir.mkdir(parents=True, exist_ok=True) # 自动创建目录
print(f"[FileUtil] 图片存储目录:{face_name_dir}")
# 5. 生成唯一文件名(毫秒级时间戳)
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
# 6. 计算路径关键修复确保所有路径都是绝对路径且在root_dir下
local_abs_path = face_name_dir / image_filename # 绝对路径
# 验证路径是否在root_dir下防止路径穿越攻击
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
raise Exception(f"图片路径不在up_images根目录下安全校验失败{local_abs_path}")
# 数据库存储路径从root_dir开始的相对路径如 up_images/192_168_110_31/小王/xxx.jpg
db_path = str(root_dir.name / local_abs_path.relative_to(root_dir))
# 7. 写入图片文件
with open(local_abs_path, "wb") as f:
f.write(image_bytes)
print(f"[FileUtil] 图片保存成功:")
print(f" 数据库路径:{db_path}")
print(f" 本地绝对路径:{local_abs_path}")
return {
"success": True,
"db_path": db_path, # 存数据库的相对路径up_images开头
"local_abs_path": str(local_abs_path), # 本地绝对路径
"msg": "图片保存成功"
}
except Exception as e:
error_msg = f"图片保存失败:{str(e)}"
print(f"[FileUtil] 错误:{error_msg}")
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}

61
util/model_util.py Normal file
View File

@ -0,0 +1,61 @@
import os
import numpy as np
import traceback
from ultralytics import YOLO
from typing import Optional
def load_yolo_model(model_path: str) -> Optional[YOLO]:
"""
加载YOLO模型支持v5/v8并校验模型有效性
:param model_path: 模型文件的绝对路径
:return: 加载成功返回YOLO模型实例失败返回None
"""
try:
# 加载前的基础信息检查
print(f"\n[模型工具] 开始加载模型:{model_path}")
print(f"[模型工具] 文件是否存在:{os.path.exists(model_path)}")
if os.path.exists(model_path):
print(f"[模型工具] 文件大小:{os.path.getsize(model_path) / 1024 / 1024:.2f} MB")
# 强制重新加载模型,避免缓存问题
model = YOLO(model_path)
# 兼容性校验使用numpy空数组测试模型
dummy_image = np.zeros((640, 640, 3), dtype=np.uint8)
try:
# 优先使用新版本参数
model.predict(
source=dummy_image,
imgsz=640,
conf=0.25,
verbose=False,
stream=False
)
except Exception as pred_e:
print(f"[模型工具] 预测校验兼容处理:{str(pred_e)}")
# 兼容旧版本YOLO参数
model.predict(
img=dummy_image,
imgsz=640,
conf=0.25,
verbose=False
)
# 验证模型基本属性
if not hasattr(model, 'names'):
print("[模型工具] 警告:模型缺少类别名称属性")
else:
print(f"[模型工具] 模型包含类别:{list(model.names.values())[:5]}...") # 显示前5个类别
print(f"[模型工具] 模型加载成功!")
return model
except Exception as e:
# 详细错误信息输出
print(f"\n[模型工具] 加载模型失败!路径:{model_path}")
print(f"[模型工具] 异常类型:{type(e).__name__}")
print(f"[模型工具] 异常详情:{str(e)}")
print(f"[模型工具] 堆栈跟踪:\n{traceback.format_exc()}")
return None