目前可以成功动态更换模型运行的
This commit is contained in:
283
app.py
Normal file
283
app.py
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
from flask import Flask, send_from_directory, abort, request
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from functools import wraps
|
||||||
|
from pathlib import Path
|
||||||
|
# 跨域依赖(必须安装:pip install flask-cors)
|
||||||
|
from flask_cors import CORS
|
||||||
|
|
||||||
|
# 配置日志(保持原有格式)
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 初始化 Flask 应用(供 main.py 导入)
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 核心修改:与 FastAPI 对齐的跨域配置
|
||||||
|
# ------------------------------
|
||||||
|
# 1. 允许的前端域名(完全复制 FastAPI 的 ALLOWED_ORIGINS,确保前后端一致)
|
||||||
|
ALLOWED_ORIGINS = [
|
||||||
|
# "http://localhost:8080", # 本地前端开发地址(必改:替换为你的前端实际地址)
|
||||||
|
# "http://127.0.0.1:8080",
|
||||||
|
# "http://服务器IP:8080", # 部署后前端地址(替换为你的服务器IP/域名)
|
||||||
|
# # "*" 仅开发环境临时使用,生产环境必须删除(安全风险)
|
||||||
|
"*"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 2. 配置 CORS(与 FastAPI 规则完全对齐)
|
||||||
|
CORS(
|
||||||
|
app,
|
||||||
|
resources={
|
||||||
|
r"/*": { # 对所有 Flask 路由生效(覆盖图片、模型下载所有接口)
|
||||||
|
"origins": ALLOWED_ORIGINS, # 允许的前端域名(与 FastAPI 一致)
|
||||||
|
"allow_credentials": True, # 允许携带 Cookie(与 FastAPI 一致,需登录态必开)
|
||||||
|
"methods": ["*"], # 允许所有 HTTP 方法(FastAPI 用 "*",此处同步)
|
||||||
|
"allow_headers": ["*"], # 允许所有请求头(与 FastAPI 一致)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 核心路径配置(不变,确保资源目录正确)
|
||||||
|
# ------------------------------
|
||||||
|
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||||
|
PROJECT_ROOT = CURRENT_FILE_PATH.parent # 项目根目录(video/)
|
||||||
|
# 资源目录(图片、模型)
|
||||||
|
BASE_IMAGE_DIR_DECT = str((PROJECT_ROOT / "resource" / "dect").resolve()) # 检测图片目录
|
||||||
|
BASE_IMAGE_DIR_UP_IMAGES = str((PROJECT_ROOT / "up_images").resolve()) # 人脸图片目录
|
||||||
|
BASE_MODEL_DIR = str((PROJECT_ROOT / "resource" / "models").resolve()) # 模型文件目录
|
||||||
|
|
||||||
|
# 打印路径配置(调试用,确认目录正确)
|
||||||
|
logger.info(f"[Flask 配置] 项目根目录:{PROJECT_ROOT}")
|
||||||
|
logger.info(f"[Flask 配置] 模型目录:{BASE_MODEL_DIR}")
|
||||||
|
logger.info(f"[Flask 配置] 人脸图片目录:{BASE_IMAGE_DIR_UP_IMAGES}")
|
||||||
|
logger.info(f"[Flask 配置] 检测图片目录:{BASE_IMAGE_DIR_DECT}")
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 安全检查装饰器(不变,防路径遍历/非法文件)
|
||||||
|
# ------------------------------
|
||||||
|
def safe_path_check(root_dir: str):
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
resource_path = kwargs.get('resource_path', '').strip()
|
||||||
|
# 统一路径分隔符(兼容 Windows \ 和 Linux /)
|
||||||
|
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||||
|
# 拼接完整路径(防止路径遍历)
|
||||||
|
full_file_path = os.path.abspath(os.path.join(root_dir, resource_path))
|
||||||
|
logger.debug(
|
||||||
|
f"[Flask 安全检查] 请求路径:{resource_path} | 完整路径:{full_file_path} | 根目录:{root_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. 禁止路径遍历(确保请求文件在根目录内)
|
||||||
|
if not full_file_path.startswith(root_dir):
|
||||||
|
logger.warning(
|
||||||
|
f"[Flask 安全拦截] 非法路径遍历!IP:{request.remote_addr} | 请求路径:{resource_path}"
|
||||||
|
)
|
||||||
|
abort(403)
|
||||||
|
|
||||||
|
# 2. 检查文件存在且为有效文件(非目录)
|
||||||
|
if not os.path.exists(full_file_path) or not os.path.isfile(full_file_path):
|
||||||
|
logger.warning(
|
||||||
|
f"[Flask 资源错误] 文件不存在/非文件!IP:{request.remote_addr} | 路径:{full_file_path}"
|
||||||
|
)
|
||||||
|
abort(404)
|
||||||
|
|
||||||
|
# 3. 限制文件大小(模型200MB,图片10MB,避免超大文件攻击)
|
||||||
|
max_size = 200 * 1024 * 1024 if "models" in root_dir else 10 * 1024 * 1024
|
||||||
|
if os.path.getsize(full_file_path) > max_size:
|
||||||
|
logger.warning(
|
||||||
|
f"[Flask 大小超限] 文件超过{max_size//1024//1024}MB!IP:{request.remote_addr} | 路径:{full_file_path}"
|
||||||
|
)
|
||||||
|
abort(413)
|
||||||
|
|
||||||
|
# 安全检查通过,传递根目录给视图函数
|
||||||
|
return func(*args, **kwargs, root_dir=root_dir)
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 1. 模型下载接口(/model/download/*)
|
||||||
|
# ------------------------------
|
||||||
|
@app.route('/model/download/<path:resource_path>')
|
||||||
|
@safe_path_check(root_dir=BASE_MODEL_DIR)
|
||||||
|
def download_model(resource_path, root_dir):
|
||||||
|
try:
|
||||||
|
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||||
|
dir_path, file_name = os.path.split(resource_path)
|
||||||
|
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||||
|
|
||||||
|
# 仅允许 .pt 格式(YOLO 模型)
|
||||||
|
if not file_name.lower().endswith('.pt'):
|
||||||
|
logger.warning(
|
||||||
|
f"[Flask 格式错误] 非 .pt 模型文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||||
|
)
|
||||||
|
abort(415)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[Flask 模型下载] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 强制浏览器下载(而非预览),设置二进制文件类型
|
||||||
|
return send_from_directory(
|
||||||
|
full_dir,
|
||||||
|
file_name,
|
||||||
|
as_attachment=True,
|
||||||
|
mimetype="application/octet-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[Flask 模型下载异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||||
|
)
|
||||||
|
abort(500)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 2. 人脸图片访问接口(/up_images/*)
|
||||||
|
# ------------------------------
|
||||||
|
@app.route('/up_images/<path:resource_path>')
|
||||||
|
@safe_path_check(root_dir=BASE_IMAGE_DIR_UP_IMAGES)
|
||||||
|
def get_face_image(resource_path, root_dir):
|
||||||
|
try:
|
||||||
|
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||||
|
dir_path, file_name = os.path.split(resource_path)
|
||||||
|
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||||
|
|
||||||
|
# 仅允许常见图片格式
|
||||||
|
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||||
|
if not file_name.lower().endswith(allowed_ext):
|
||||||
|
logger.warning(
|
||||||
|
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||||
|
)
|
||||||
|
abort(415)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[Flask 人脸图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 允许浏览器预览图片(而非下载)
|
||||||
|
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[Flask 人脸图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||||
|
)
|
||||||
|
abort(500)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 3. 检测图片访问接口(/resource/dect/*)
|
||||||
|
# ------------------------------
|
||||||
|
@app.route('/resource/dect/<path:resource_path>')
|
||||||
|
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
|
||||||
|
def get_dect_image(resource_path, root_dir):
|
||||||
|
try:
|
||||||
|
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||||
|
dir_path, file_name = os.path.split(resource_path)
|
||||||
|
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||||
|
|
||||||
|
# 仅允许常见图片格式
|
||||||
|
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||||
|
if not file_name.lower().endswith(allowed_ext):
|
||||||
|
logger.warning(
|
||||||
|
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||||
|
)
|
||||||
|
abort(415)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[Flask 检测图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[Flask 检测图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||||
|
)
|
||||||
|
abort(500)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 4. 兼容旧图片接口(/images/* → 映射到 /resource/dect/*)
|
||||||
|
# ------------------------------
|
||||||
|
@app.route('/images/<path:resource_path>')
|
||||||
|
@safe_path_check(root_dir=BASE_IMAGE_DIR_DECT)
|
||||||
|
def get_compatible_image(resource_path, root_dir):
|
||||||
|
try:
|
||||||
|
# 逻辑与检测图片接口一致,仅URL前缀不同(兼容旧前端)
|
||||||
|
resource_path = resource_path.replace("/", os.sep).replace("\\", os.sep)
|
||||||
|
dir_path, file_name = os.path.split(resource_path)
|
||||||
|
full_dir = os.path.abspath(os.path.join(root_dir, dir_path))
|
||||||
|
|
||||||
|
allowed_ext = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
|
||||||
|
if not file_name.lower().endswith(allowed_ext):
|
||||||
|
logger.warning(
|
||||||
|
f"[Flask 格式错误] 非图片文件!IP:{request.remote_addr} | 文件名:{file_name}"
|
||||||
|
)
|
||||||
|
abort(415)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[Flask 兼容图片] 成功请求!IP:{request.remote_addr} | 文件:{file_name} | 目录:{full_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return send_from_directory(full_dir, file_name, as_attachment=False)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[Flask 兼容图片异常] IP:{request.remote_addr} | 错误:{str(e)}"
|
||||||
|
)
|
||||||
|
abort(500)
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 全局错误处理器(友好提示,与 FastAPI 错误信息风格一致)
|
||||||
|
# ------------------------------
|
||||||
|
@app.errorhandler(403)
|
||||||
|
def forbidden_error(error):
|
||||||
|
return "❌ 禁止访问:路径非法(可能存在路径遍历)或无权限", 403
|
||||||
|
|
||||||
|
@app.errorhandler(404)
|
||||||
|
def not_found_error(error):
|
||||||
|
return "❌ 资源不存在:请检查URL路径(IP、目录、文件名)是否正确", 404
|
||||||
|
|
||||||
|
@app.errorhandler(413)
|
||||||
|
def too_large_error(error):
|
||||||
|
return "❌ 文件过大:图片最大10MB,模型最大200MB", 413
|
||||||
|
|
||||||
|
@app.errorhandler(415)
|
||||||
|
def unsupported_type_error(error):
|
||||||
|
return "❌ 不支持的文件类型:图片支持png/jpg/jpeg/gif/bmp,模型仅支持pt", 415
|
||||||
|
|
||||||
|
@app.errorhandler(500)
|
||||||
|
def server_error(error):
|
||||||
|
return "❌ 服务器内部错误:请联系管理员查看后台日志", 500
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# Flask 独立启动入口(供测试,实际由 main.py 子线程启动)
|
||||||
|
# ------------------------------
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 确保所有资源目录存在(防止初始化失败)
|
||||||
|
required_dirs = [
|
||||||
|
(BASE_IMAGE_DIR_DECT, "检测图片目录"),
|
||||||
|
(BASE_IMAGE_DIR_UP_IMAGES, "人脸图片目录"),
|
||||||
|
(BASE_MODEL_DIR, "模型文件目录")
|
||||||
|
]
|
||||||
|
for dir_path, dir_desc in required_dirs:
|
||||||
|
if not os.path.exists(dir_path):
|
||||||
|
logger.info(f"[Flask 初始化] {dir_desc}不存在,创建:{dir_path}")
|
||||||
|
os.makedirs(dir_path, exist_ok=True)
|
||||||
|
|
||||||
|
# 启动提示(含访问示例)
|
||||||
|
logger.info("\n[Flask 服务启动成功!] 支持的接口:")
|
||||||
|
logger.info(f"1. 模型下载 → http://服务器IP:5000/model/download/resource/models/xxx.pt")
|
||||||
|
logger.info(f"2. 人脸图片 → http://服务器IP:5000/up_images/xxx.jpg")
|
||||||
|
logger.info(f"3. 检测图片 → http://服务器IP:5000/resource/dect/xxx.jpg 或 http://服务器IP:5000/images/xxx.jpg\n")
|
||||||
|
|
||||||
|
# 启动服务(禁用 debug 和自动重载,避免多线程冲突)
|
||||||
|
app.run(
|
||||||
|
host="0.0.0.0", # 允许外部IP访问
|
||||||
|
port=5000, # 与 main.py 中 Flask 端口一致
|
||||||
|
debug=False,
|
||||||
|
use_reloader=False
|
||||||
|
)
|
41
core/all.py
41
core/all.py
@ -57,50 +57,39 @@ def save_db(model_type, client_ip, result):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 修正后的 detect 函数关键部分
|
||||||
def detect(client_ip, frame):
|
def detect(client_ip, frame):
|
||||||
"""
|
# 1. YOLO检测
|
||||||
执行模型检测,检测到违规时按指定格式保存图片
|
|
||||||
参数:
|
|
||||||
frame: 待检测的图像帧(OpenCV格式,numpy.ndarray类型)
|
|
||||||
返回:
|
|
||||||
(检测结果布尔值, 检测详情, 检测模型类型)
|
|
||||||
"""
|
|
||||||
# 1. YOLO检测(优先级1)
|
|
||||||
yolo_flag, yolo_result = yoloDetect(frame)
|
yolo_flag, yolo_result = yoloDetect(frame)
|
||||||
print(f"YOLO检测结果:{yolo_result}")
|
|
||||||
if yolo_flag:
|
if yolo_flag:
|
||||||
|
# model_type 传入 "yolo"(正确)
|
||||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
||||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
if full_save_path:
|
||||||
cv2.imwrite(full_save_path, frame)
|
cv2.imwrite(full_save_path, frame)
|
||||||
# 打印时使用「显示用短路径」,符合需求格式
|
print(f"✅ yolo违规图片已保存:{display_path}") # 日志也修正
|
||||||
print(f"✅ YOLO违规图片已保存:{display_path}")
|
|
||||||
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
|
save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path))
|
||||||
return (True, yolo_result, "yolo")
|
return (True, yolo_result, "yolo")
|
||||||
#
|
|
||||||
# # 2. 人脸检测(优先级2)
|
# 2. 人脸检测
|
||||||
face_flag, face_result = faceDetect(frame)
|
face_flag, face_result = faceDetect(frame)
|
||||||
print(f"人脸检测结果:{face_result}")
|
|
||||||
if face_flag:
|
if face_flag:
|
||||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
full_save_path, display_path = get_image_save_path(model_type="face", client_ip=client_ip) # 这里改了
|
||||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
if full_save_path:
|
||||||
cv2.imwrite(full_save_path, frame)
|
cv2.imwrite(full_save_path, frame)
|
||||||
# 打印时使用「显示用短路径」,符合需求格式
|
print(f"✅ face违规图片已保存:{display_path}") # 日志也修正
|
||||||
print(f"✅ face违规图片已保存:{display_path}")
|
|
||||||
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
|
save_db(model_type="face", client_ip=client_ip, result=str(full_save_path))
|
||||||
return (True, face_result, "face")
|
return (True, face_result, "face")
|
||||||
|
|
||||||
# 3. OCR检测(优先级3)
|
# 3. OCR检测
|
||||||
ocr_flag, ocr_result = ocrDetect(frame)
|
ocr_flag, ocr_result = ocrDetect(frame)
|
||||||
print(f"OCR检测结果:{ocr_result}")
|
|
||||||
if ocr_flag:
|
if ocr_flag:
|
||||||
# 解构元组,保存用完整路径,打印用短路径
|
full_save_path, display_path = get_image_save_path(model_type="ocr", client_ip=client_ip) # 这里改了
|
||||||
full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip)
|
if full_save_path:
|
||||||
if full_save_path: # 只判断完整路径是否有效(用于保存)
|
|
||||||
cv2.imwrite(full_save_path, frame)
|
cv2.imwrite(full_save_path, frame)
|
||||||
# 打印时使用「显示用短路径」,符合需求格式
|
print(f"✅ ocr违规图片已保存:{display_path}") # 日志也修正
|
||||||
print(f"✅ ocr违规图片已保存:{display_path}")
|
|
||||||
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
|
save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path))
|
||||||
return (True, ocr_result, "ocr")
|
return (True, ocr_result, "ocr")
|
||||||
|
|
||||||
# 4. 无违规内容(不保存图片)
|
# 4. 无违规内容(不保存图片)
|
||||||
print(f"❌ 未检测到任何违规内容,不保存图片")
|
print(f"❌ 未检测到任何违规内容,不保存图片")
|
||||||
return (False, "未检测到任何内容", "none")
|
return (False, "未检测到任何内容", "none")
|
47
core/yolo.py
47
core/yolo.py
@ -1,37 +1,43 @@
|
|||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
from service.model_service import get_current_yolo_model # 从模型管理模块获取模型
|
||||||
|
|
||||||
# 全局变量
|
# 全局模型变量
|
||||||
_yolo_model = None
|
_yolo_model = None
|
||||||
|
|
||||||
|
|
||||||
model_path = os.path.join(os.path.dirname(__file__), "models", "best.pt")
|
def load_model(model_path=None):
|
||||||
|
"""加载YOLO模型(优先使用模型管理模块的默认模型)"""
|
||||||
|
|
||||||
def load_model():
|
|
||||||
"""加载YOLO目标检测模型"""
|
|
||||||
global _yolo_model
|
global _yolo_model
|
||||||
|
|
||||||
|
if model_path is None:
|
||||||
|
_yolo_model = get_current_yolo_model()
|
||||||
|
return _yolo_model is not None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_yolo_model = YOLO(model_path)
|
_yolo_model = YOLO(model_path)
|
||||||
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"YOLO model load failed: {e}")
|
print(f"YOLO模型加载失败(指定路径):{str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True if _yolo_model else False
|
|
||||||
|
|
||||||
|
|
||||||
def detect(frame, conf_threshold=0.2):
|
def detect(frame, conf_threshold=0.2):
|
||||||
"""YOLO目标检测、返回(是否识别到, 结果字符串)"""
|
"""执行目标检测,返回(是否成功, 结果字符串)"""
|
||||||
global _yolo_model
|
global _yolo_model
|
||||||
|
|
||||||
if not _yolo_model or frame is None:
|
# 确保模型已加载
|
||||||
return (False, "未初始化或无效帧")
|
if not _yolo_model:
|
||||||
|
if not load_model():
|
||||||
|
return (False, "模型未初始化")
|
||||||
|
|
||||||
|
if frame is None:
|
||||||
|
return (False, "无效输入帧")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = _yolo_model(frame, conf=conf_threshold)
|
# 执行检测(frame应为numpy数组)
|
||||||
# 检查是否有检测结果
|
results = _yolo_model(frame, conf=conf_threshold, verbose=False)
|
||||||
has_results = len(results[0].boxes) > 0 if results else False
|
has_results = len(results[0].boxes) > 0 if results else False
|
||||||
|
|
||||||
if not has_results:
|
if not has_results:
|
||||||
@ -42,13 +48,12 @@ def detect(frame, conf_threshold=0.2):
|
|||||||
for box in results[0].boxes:
|
for box in results[0].boxes:
|
||||||
cls = int(box.cls[0])
|
cls = int(box.cls[0])
|
||||||
conf = float(box.conf[0])
|
conf = float(box.conf[0])
|
||||||
bbox = [float(x) for x in box.xyxy[0]]
|
bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数
|
||||||
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
|
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
|
||||||
result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")
|
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})")
|
||||||
|
|
||||||
result_str = "; ".join(result_parts)
|
return (True, "; ".join(result_parts))
|
||||||
return (has_results, result_str)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"YOLO detect error: {e}")
|
print(f"检测过程出错:{str(e)}")
|
||||||
return (False, f"检测错误: {str(e)}")
|
return (False, f"检测错误:{str(e)}")
|
||||||
|
130
main.py
130
main.py
@ -1,62 +1,142 @@
|
|||||||
from PIL import Image # 正确导入
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import uvicorn
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import os
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from core.all import load_model,detect
|
# 新增:导入 CORS 相关依赖
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
# 导入 Flask 服务实例
|
||||||
|
from app import app as flask_app
|
||||||
|
|
||||||
|
# 原有业务导入
|
||||||
|
from core.all import load_model, detect
|
||||||
from ds.config import SERVER_CONFIG
|
from ds.config import SERVER_CONFIG
|
||||||
from middle.error_handler import global_exception_handler
|
from middle.error_handler import global_exception_handler
|
||||||
from service.user_service import router as user_router
|
from service.user_service import router as user_router
|
||||||
from service.sensitive_service import router as sensitive_router
|
from service.sensitive_service import router as sensitive_router
|
||||||
from service.face_service import router as face_router
|
from service.face_service import router as face_router
|
||||||
from service.device_service import router as device_router
|
from service.device_service import router as device_router
|
||||||
|
from service.model_service import router as model_router # 模型管理路由
|
||||||
from ws.ws import ws_router, lifespan
|
from ws.ws import ws_router, lifespan
|
||||||
from core.establish import create_directory_structure
|
from core.establish import create_directory_structure
|
||||||
|
|
||||||
# ------------------------------
|
|
||||||
# 初始化 FastAPI 应用、指定生命周期管理
|
# Flask 服务启动函数(不变)
|
||||||
# ------------------------------
|
def start_flask_service():
|
||||||
|
try:
|
||||||
|
print(f"\n[Flask 服务] 准备启动,端口:5000")
|
||||||
|
print(f"[Flask 服务] 访问示例:http://服务器IP:5000/resource/dect/ocr/xxx.jpg\n")
|
||||||
|
|
||||||
|
BASE_IMAGE_DIR = os.path.abspath(os.path.join("resource", "dect"))
|
||||||
|
if not os.path.exists(BASE_IMAGE_DIR):
|
||||||
|
print(f"[Flask 服务] 图片根目录不存在,创建:{BASE_IMAGE_DIR}")
|
||||||
|
os.makedirs(BASE_IMAGE_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
flask_app.run(
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=5000,
|
||||||
|
debug=False,
|
||||||
|
use_reloader=False
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Flask 服务] 启动失败:{str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化 FastAPI 应用(新增 CORS 配置)
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="内容安全审核后台",
|
title="内容安全审核后台",
|
||||||
description="内容安全审核后台",
|
description="含图片访问服务和动态模型管理",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
lifespan=lifespan
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 注册路由
|
# 新增:完整 CORS 配置(解决跨域问题)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
|
# 1. 允许的前端域名(根据实际情况修改!本地开发通常是 http://localhost:8080 等)
|
||||||
|
ALLOWED_ORIGINS = [
|
||||||
|
# "http://localhost:8080", # 前端本地开发地址(必改,填实际前端地址)
|
||||||
|
# "http://127.0.0.1:8080",
|
||||||
|
# "http://服务器IP:8080", # 部署后前端地址(如适用)
|
||||||
|
"*" #表示允许所有域名(开发环境可用,生产环境不推荐)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 2. 配置 CORS 中间件
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=ALLOWED_ORIGINS, # 允许的前端域名
|
||||||
|
allow_credentials=True, # 允许携带 Cookie(如需登录态则必开)
|
||||||
|
allow_methods=["*"], # 允许所有 HTTP 方法(包括 PUT/DELETE)
|
||||||
|
allow_headers=["*"], # 允许所有请求头(包括 Content-Type)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由(不变)
|
||||||
app.include_router(user_router)
|
app.include_router(user_router)
|
||||||
app.include_router(device_router)
|
app.include_router(device_router)
|
||||||
app.include_router(face_router)
|
app.include_router(face_router)
|
||||||
app.include_router(sensitive_router)
|
app.include_router(sensitive_router)
|
||||||
|
app.include_router(model_router) # 模型管理路由
|
||||||
app.include_router(ws_router)
|
app.include_router(ws_router)
|
||||||
|
|
||||||
# ------------------------------
|
# 注册全局异常处理器(不变)
|
||||||
# 注册全局异常处理器
|
|
||||||
# ------------------------------
|
|
||||||
app.add_exception_handler(Exception, global_exception_handler)
|
app.add_exception_handler(Exception, global_exception_handler)
|
||||||
|
|
||||||
# ------------------------------
|
# 主服务启动入口(不变)
|
||||||
# 启动服务
|
|
||||||
# ------------------------------
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# -------------------------- 配置调整 --------------------------
|
# 1. 初始化资源
|
||||||
# 模型配置路径(建议改为环境变量)
|
|
||||||
YOLO_MODEL_PATH = r"/core/models\best.pt"
|
|
||||||
OCR_CONFIG_PATH = r"/core/config\config.yaml"
|
|
||||||
|
|
||||||
create_directory_structure()
|
create_directory_structure()
|
||||||
|
print(f"[初始化] 目录结构创建完成")
|
||||||
|
|
||||||
# 初始化项目(默认端口设为8000、避免初始化失败时port未定义)
|
# 创建模型保存目录
|
||||||
|
MODEL_SAVE_DIR = os.path.join("core", "models")
|
||||||
|
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||||||
|
print(f"[初始化] 模型保存目录:{MODEL_SAVE_DIR}")
|
||||||
|
|
||||||
|
# # 模型路径配置
|
||||||
|
# YOLO_MODEL_PATH = os.path.join("core", "models", "best.pt")
|
||||||
|
# OCR_CONFIG_PATH = os.path.join("core", "config", "config.yaml")
|
||||||
|
# print(f"[初始化] 默认YOLO模型路径:{YOLO_MODEL_PATH}")
|
||||||
|
# print(f"[初始化] OCR 配置路径:{OCR_CONFIG_PATH}")
|
||||||
|
|
||||||
|
# 加载检测模型
|
||||||
|
try:
|
||||||
|
load_success = load_model()
|
||||||
|
if load_success:
|
||||||
|
print(f"[初始化] 检测模型加载完成")
|
||||||
|
else:
|
||||||
|
print(f"[初始化] 未找到默认模型,可通过API上传并设置默认模型")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[初始化] 模型加载警告:{str(e)}(服务仍启动)")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 2. 启动 Flask 服务(子线程)
|
||||||
|
flask_thread = threading.Thread(
|
||||||
|
target=start_flask_service,
|
||||||
|
daemon=True
|
||||||
|
)
|
||||||
|
flask_thread.start()
|
||||||
|
|
||||||
|
# 等待 Flask 初始化
|
||||||
|
time.sleep(1)
|
||||||
|
if flask_thread.is_alive():
|
||||||
|
print(f"[Flask 服务] 启动成功(运行中)")
|
||||||
|
else:
|
||||||
|
print(f"[Flask 服务] 启动失败!图片访问不可用")
|
||||||
|
|
||||||
|
# 3. 启动 FastAPI 主服务
|
||||||
port = int(SERVER_CONFIG.get("port", 8000))
|
port = int(SERVER_CONFIG.get("port", 8000))
|
||||||
|
print(f"\n[FastAPI 服务] 准备启动,端口:{port}")
|
||||||
|
print(f"[FastAPI 服务] 接口文档:http://服务器IP:{port}/docs\n")
|
||||||
|
|
||||||
# 启动 UVicorn 服务
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app="main:app",
|
app="main:app",
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=port,
|
port=port,
|
||||||
workers=8,
|
workers=1,
|
||||||
ws="websockets"
|
ws="websockets",
|
||||||
|
reload=False
|
||||||
)
|
)
|
@ -1,30 +1,41 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 请求模型(前端传参校验)
|
# 请求模型(前端传参校验)- 保留update的eigenvalue(如需更新特征值)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
class FaceCreateRequest(BaseModel):
|
class FaceCreateRequest(BaseModel):
|
||||||
"""创建人脸记录请求模型(无需ID、由数据库自增)"""
|
"""创建人脸记录请求模型(无需ID、由数据库自增)"""
|
||||||
name: str = Field(None, max_length=255, description="名称(可选、最长255字符)")
|
name: Optional[str] = Field(None, max_length=255, description="名称(可选、最长255字符)")
|
||||||
|
|
||||||
|
|
||||||
class FaceUpdateRequest(BaseModel):
|
class FaceUpdateRequest(BaseModel):
|
||||||
"""更新人脸记录请求模型(不变)"""
|
"""更新人脸记录请求模型 - 保留eigenvalue(如需更新特征值,不影响返回)"""
|
||||||
name: str = Field(None, max_length=255, description="名称")
|
name: Optional[str] = Field(None, max_length=255, description="名称(可选)")
|
||||||
eigenvalue: str = Field(None, max_length=255, description="特征(文件处理后可更新)")
|
eigenvalue: Optional[str] = Field(None, description="特征值(可选,文件处理后可更新)") # 保留更新能力
|
||||||
|
address: Optional[str] = Field(None, description="图片完整路径(可选,更新图片时使用)")
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 响应模型(后端返回数据)
|
# 响应模型(后端返回数据)- 核心修改:删除eigenvalue字段
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
class FaceResponse(BaseModel):
|
class FaceResponse(BaseModel):
|
||||||
"""人脸记录响应模型(仍包含ID、由数据库生成后返回)"""
|
"""人脸记录响应模型(仅返回需要的字段,移除eigenvalue)"""
|
||||||
id: int = Field(..., description="主键ID(数据库自增)")
|
id: int = Field(..., description="主键ID(数据库自增)")
|
||||||
name: str = Field(None, description="名称")
|
name: Optional[str] = Field(None, description="名称")
|
||||||
eigenvalue: str | None = Field(None, description="特征(可为空)")
|
address: Optional[str] = Field(None, description="人脸图片完整保存路径(数据库新增字段)") # 仅保留address
|
||||||
created_at: datetime = Field(..., description="记录创建时间")
|
created_at: datetime = Field(..., description="记录创建时间(数据库自动生成)")
|
||||||
updated_at: datetime = Field(..., description="记录更新时间")
|
updated_at: datetime = Field(..., description="记录更新时间(数据库自动生成)")
|
||||||
|
|
||||||
|
# 关键配置:支持从数据库查询结果(字典)直接转换
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class FaceListResponse(BaseModel):
|
||||||
|
"""人脸列表分页响应模型(结构不变,内部FaceResponse已移除eigenvalue)"""
|
||||||
|
total: int = Field(..., description="筛选后的总记录数")
|
||||||
|
faces: List[FaceResponse] = Field(..., description="当前页的人脸记录列表")
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
37
schema/model_schema.py
Normal file
37
schema/model_schema.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
# 请求模型
|
||||||
|
class ModelCreateRequest(BaseModel):
|
||||||
|
name: str = Field(..., max_length=255, description="模型名称(必填,如:yolo-v8s-car)")
|
||||||
|
description: Optional[str] = Field(None, description="模型描述(可选)")
|
||||||
|
is_default: Optional[bool] = Field(False, description="是否设为默认模型")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelUpdateRequest(BaseModel):
|
||||||
|
name: Optional[str] = Field(None, max_length=255, description="模型名称(可选修改)")
|
||||||
|
description: Optional[str] = Field(None, description="模型描述(可选修改)")
|
||||||
|
is_default: Optional[bool] = Field(None, description="是否设为默认模型(可选切换)")
|
||||||
|
|
||||||
|
|
||||||
|
# 响应模型
|
||||||
|
class ModelResponse(BaseModel):
|
||||||
|
id: int = Field(..., description="模型ID")
|
||||||
|
name: str = Field(..., description="模型名称")
|
||||||
|
path: str = Field(..., description="模型文件相对路径")
|
||||||
|
is_default: bool = Field(..., description="是否默认模型")
|
||||||
|
description: Optional[str] = Field(None, description="模型描述")
|
||||||
|
file_size: Optional[int] = Field(None, description="文件大小(字节)")
|
||||||
|
created_at: datetime = Field(..., description="创建时间")
|
||||||
|
updated_at: datetime = Field(..., description="更新时间")
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class ModelListResponse(BaseModel):
|
||||||
|
total: int = Field(..., description="总记录数")
|
||||||
|
models: List[ModelResponse] = Field(..., description="当前页模型列表")
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
@ -1,6 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -30,3 +30,11 @@ class UserResponse(BaseModel):
|
|||||||
|
|
||||||
# Pydantic V2 配置(支持从数据库查询结果转换)
|
# Pydantic V2 配置(支持从数据库查询结果转换)
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class UserListResponse(BaseModel):
|
||||||
|
"""用户列表分页响应模型(与设备/人脸列表结构对齐)"""
|
||||||
|
total: int = Field(..., description="用户总数")
|
||||||
|
users: List[UserResponse] = Field(..., description="当前页用户列表")
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
@ -1,162 +1,140 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
from mysql.connector import Error as MySQLError
|
from mysql.connector import Error as MySQLError
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse
|
from schema.face_schema import (
|
||||||
from schema.response_schema import APIResponse
|
FaceCreateRequest,
|
||||||
from middle.auth_middleware import get_current_user
|
FaceUpdateRequest,
|
||||||
from schema.user_schema import UserResponse
|
FaceResponse,
|
||||||
|
FaceListResponse
|
||||||
from util.face_util import add_binary_data,get_average_feature
|
|
||||||
#初始化实例
|
|
||||||
|
|
||||||
router = APIRouter(
|
|
||||||
prefix="/faces",
|
|
||||||
tags=["人脸管理"]
|
|
||||||
)
|
)
|
||||||
|
from schema.response_schema import APIResponse
|
||||||
|
|
||||||
|
from util.face_util import add_binary_data, get_average_feature
|
||||||
|
from util.file_util import save_face_to_up_images
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/faces", tags=["人脸管理"])
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 1. 创建人脸记录(核心修正: ID 数据库自增、前端无需传)
|
# 1. 创建人脸记录(使用修复后的路径)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件、ID自增)")
|
@router.post("", response_model=APIResponse, summary="创建人脸记录")
|
||||||
async def create_face(
|
async def create_face(
|
||||||
# 前端仅需传: name(可选、Form格式)、file(必传、文件)
|
request: Request,
|
||||||
name: str = Form(None, max_length=255, description="名称(可选)"),
|
name: str = Form(None, max_length=255, description="名称(可选)"),
|
||||||
file: UploadFile = File(..., description="人脸文件(必传、暂不处理内容)")
|
file: UploadFile = File(..., description="人脸文件(必传)")
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
创建人脸记录:
|
|
||||||
- 需登录认证
|
|
||||||
- 前端传参: multipart/form-data 表单(name 可选、file 必传)
|
|
||||||
- ID 由数据库自动生成、无需前端传入
|
|
||||||
- 暂不处理文件内容、eigenvalue 设为 None
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 调用你的方法
|
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
# 1. 用模型校验 name(仅校验长度、无需ID)
|
|
||||||
face_create = FaceCreateRequest(name=name)
|
face_create = FaceCreateRequest(name=name)
|
||||||
|
client_ip = request.client.host if request.client else ""
|
||||||
|
if not client_ip:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取客户端IP")
|
||||||
|
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 把文件转为二进制数组
|
# 读取图片并保存(使用修复后的路径逻辑)
|
||||||
file_content = await file.read()
|
file_content = await file.read()
|
||||||
|
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "jpg"
|
||||||
# 计算特征值
|
save_result = save_face_to_up_images(
|
||||||
flag, eigenvalue = add_binary_data(file_content)
|
client_ip=client_ip,
|
||||||
|
face_name=name,
|
||||||
if flag == False:
|
image_bytes=file_content,
|
||||||
raise HTTPException(
|
image_format=file_ext
|
||||||
status_code=500,
|
|
||||||
detail="未检测到人脸"
|
|
||||||
)
|
)
|
||||||
|
if not save_result["success"]:
|
||||||
|
raise HTTPException(status_code=500, detail=f"图片保存失败:{save_result['msg']}")
|
||||||
|
db_image_path = save_result["db_path"] # 从修复后的方法获取路径
|
||||||
|
|
||||||
# 打印数组长度
|
# 提取人脸特征
|
||||||
print(f"文件大小: {len(file_content)} 字节")
|
detect_success, detect_result = add_binary_data(file_content)
|
||||||
|
if not detect_success:
|
||||||
|
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
|
||||||
|
eigenvalue = detect_result
|
||||||
|
|
||||||
# 2. 插入数据库: 无需传 ID(自增)、只传 name 和 eigenvalue(None)
|
# 插入数据库
|
||||||
insert_query = """
|
insert_query = """
|
||||||
INSERT INTO face (name, eigenvalue)
|
INSERT INTO face (name, eigenvalue, address)
|
||||||
VALUES (%s, %s)
|
VALUES (%s, %s, %s)
|
||||||
"""
|
"""
|
||||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
|
cursor.execute(insert_query, (face_create.name, str(eigenvalue), db_image_path))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
# 3. 获取数据库自动生成的 ID(关键: 用 LAST_INSERT_ID() 查刚插入的记录)
|
# 查询新记录
|
||||||
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
|
cursor.execute("""
|
||||||
cursor.execute(select_new_query)
|
SELECT id, name, address, created_at, updated_at
|
||||||
|
FROM face
|
||||||
|
WHERE id = LAST_INSERT_ID()
|
||||||
|
""")
|
||||||
created_face = cursor.fetchone()
|
created_face = cursor.fetchone()
|
||||||
|
|
||||||
if not created_face:
|
if not created_face:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
|
||||||
status_code=500,
|
|
||||||
detail="创建人脸记录成功、但无法获取新创建的记录"
|
|
||||||
)
|
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=201,
|
code=201,
|
||||||
message=f"人脸记录创建成功(ID: {created_face['id']}、文件名: {file.filename})",
|
message=f"人脸记录创建成功(ID: {created_face['id']})",
|
||||||
data=FaceResponse(** created_face)
|
data=FaceResponse(**created_face)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
# 改为使用HTTPException
|
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail=f"创建人脸记录失败: {str(e)}"
|
|
||||||
) from e
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 捕获其他可能的异常
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail=f"服务器错误: {str(e)}"
|
|
||||||
) from e
|
|
||||||
finally:
|
finally:
|
||||||
await file.close() # 关闭文件流
|
await file.close()
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
|
|
||||||
flag, eigenvalue = add_binary_data(file_content)
|
|
||||||
if flag == False:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="未检测到人脸"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 将 eigenvalue 转为 str
|
|
||||||
eigenvalue = str(eigenvalue)
|
|
||||||
|
|
||||||
|
# 其他接口(获取单条/列表、更新、删除、获取图片)与之前一致,无需修改
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 2. 获取单个人脸记录(不变、用自增ID查询)
|
# 2. 获取单个人脸记录
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
||||||
async def get_face(
|
async def get_face(face_id: int):
|
||||||
face_id: int, # 这里的 ID 是数据库自增的、前端从创建响应中获取
|
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
query = "SELECT * FROM face WHERE id = %s"
|
query = """
|
||||||
|
SELECT id, name, address, created_at, updated_at
|
||||||
|
FROM face
|
||||||
|
WHERE id = %s
|
||||||
|
"""
|
||||||
cursor.execute(query, (face_id,))
|
cursor.execute(query, (face_id,))
|
||||||
face = cursor.fetchone()
|
face = cursor.fetchone()
|
||||||
|
|
||||||
if not face:
|
if not face:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||||
status_code=404,
|
|
||||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
|
||||||
)
|
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="人脸记录查询成功",
|
message="查询成功",
|
||||||
data=FaceResponse(**face)
|
data=FaceResponse(**face)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
# 改为使用HTTPException
|
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail=f"查询人脸记录失败: {str(e)}"
|
|
||||||
) from e
|
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理)
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 3. 获取所有人脸记录(不变)
|
# 3. 获取人脸列表
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
|
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
|
||||||
async def get_all_faces(
|
async def get_face_list(
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
page_size: int = Query(10, ge=1, le=100),
|
||||||
|
name: str = Query(None),
|
||||||
|
has_eigenvalue: bool = Query(None)
|
||||||
):
|
):
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
@ -164,50 +142,66 @@ async def get_all_faces(
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
query = "SELECT * FROM face ORDER BY id" # 按自增ID排序
|
where_clause = []
|
||||||
cursor.execute(query)
|
params = []
|
||||||
faces = cursor.fetchall()
|
if name:
|
||||||
|
where_clause.append("name LIKE %s")
|
||||||
|
params.append(f"%{name}%")
|
||||||
|
if has_eigenvalue is not None:
|
||||||
|
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
|
||||||
|
|
||||||
|
# 总记录数
|
||||||
|
count_query = "SELECT COUNT(*) AS total FROM face"
|
||||||
|
if where_clause:
|
||||||
|
count_query += " WHERE " + " AND ".join(where_clause)
|
||||||
|
cursor.execute(count_query, params)
|
||||||
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
|
# 列表数据
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
list_query = """
|
||||||
|
SELECT id, name, address, created_at, updated_at
|
||||||
|
FROM face
|
||||||
|
"""
|
||||||
|
if where_clause:
|
||||||
|
list_query += " WHERE " + " AND ".join(where_clause)
|
||||||
|
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
|
||||||
|
cursor.execute(list_query, params)
|
||||||
|
face_list = cursor.fetchall()
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="所有人脸记录查询成功",
|
message=f"获取成功(共{total}条)",
|
||||||
data=[FaceResponse(** face) for face in faces]
|
data=FaceListResponse(
|
||||||
|
total=total,
|
||||||
|
faces=[FaceResponse(**face) for face in face_list]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||||
status_code=500,
|
|
||||||
detail=f"查询所有人脸记录失败: {str(e)}"
|
|
||||||
) from e
|
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 4. 更新人脸记录(不变、用自增ID更新)
|
# 4. 更新人脸记录
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
|
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
|
||||||
async def update_face(
|
async def update_face(face_id: int, face_update: FaceUpdateRequest):
|
||||||
face_id: int,
|
|
||||||
face_update: FaceUpdateRequest,
|
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 检查记录是否存在
|
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
|
||||||
check_query = "SELECT id FROM face WHERE id = %s"
|
exist_face = cursor.fetchone()
|
||||||
cursor.execute(check_query, (face_id,))
|
if not exist_face:
|
||||||
existing_face = cursor.fetchone()
|
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||||
if not existing_face:
|
old_db_path = exist_face["address"]
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建更新语句
|
|
||||||
update_fields = []
|
update_fields = []
|
||||||
params = []
|
params = []
|
||||||
if face_update.name is not None:
|
if face_update.name is not None:
|
||||||
@ -216,6 +210,18 @@ async def update_face(
|
|||||||
if face_update.eigenvalue is not None:
|
if face_update.eigenvalue is not None:
|
||||||
update_fields.append("eigenvalue = %s")
|
update_fields.append("eigenvalue = %s")
|
||||||
params.append(face_update.eigenvalue)
|
params.append(face_update.eigenvalue)
|
||||||
|
if face_update.address is not None:
|
||||||
|
# 删除旧图片(相对路径转绝对路径)
|
||||||
|
if old_db_path:
|
||||||
|
old_abs_path = Path(old_db_path).resolve()
|
||||||
|
if old_abs_path.exists():
|
||||||
|
try:
|
||||||
|
old_abs_path.unlink() # 使用Path方法删除更安全
|
||||||
|
print(f"[FaceRouter] 已删除旧图片:{old_abs_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[FaceRouter] 删除旧图片失败:{str(e)}")
|
||||||
|
update_fields.append("address = %s")
|
||||||
|
params.append(face_update.address)
|
||||||
|
|
||||||
if not update_fields:
|
if not update_fields:
|
||||||
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
||||||
@ -225,117 +231,143 @@ async def update_face(
|
|||||||
cursor.execute(update_query, params)
|
cursor.execute(update_query, params)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
# 查询更新后记录
|
cursor.execute("""
|
||||||
select_query = "SELECT * FROM face WHERE id = %s"
|
SELECT id, name, address, created_at, updated_at
|
||||||
cursor.execute(select_query, (face_id,))
|
FROM face
|
||||||
|
WHERE id = %s
|
||||||
|
""", (face_id,))
|
||||||
updated_face = cursor.fetchone()
|
updated_face = cursor.fetchone()
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="人脸记录更新成功",
|
message="更新成功",
|
||||||
data=FaceResponse(**updated_face)
|
data=FaceResponse(**updated_face)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=f"更新失败: {str(e)}") from e
|
||||||
status_code=500,
|
|
||||||
detail=f"更新人脸记录失败: {str(e)}"
|
|
||||||
) from e
|
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 5. 删除人脸记录(不变、用自增ID删除)
|
# 5. 删除人脸记录
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
||||||
async def delete_face(
|
async def delete_face(face_id: int):
|
||||||
face_id: int,
|
|
||||||
current_user: UserResponse = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
check_query = "SELECT id FROM face WHERE id = %s"
|
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
|
||||||
cursor.execute(check_query, (face_id,))
|
exist_face = cursor.fetchone()
|
||||||
existing_face = cursor.fetchone()
|
if not exist_face:
|
||||||
if not existing_face:
|
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||||
raise HTTPException(
|
old_db_path = exist_face["address"]
|
||||||
status_code=404,
|
|
||||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
|
||||||
)
|
|
||||||
|
|
||||||
delete_query = "DELETE FROM face WHERE id = %s"
|
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
|
||||||
cursor.execute(delete_query, (face_id,))
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
# 删除图片
|
||||||
|
extra_msg = ""
|
||||||
|
if old_db_path:
|
||||||
|
old_abs_path = Path(old_db_path).resolve()
|
||||||
|
if old_abs_path.exists():
|
||||||
|
try:
|
||||||
|
old_abs_path.unlink()
|
||||||
|
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
|
||||||
|
extra_msg = "(已同步删除图片)"
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[FaceRouter] 删除图片失败:{str(e)}")
|
||||||
|
extra_msg = "(图片删除失败)"
|
||||||
|
else:
|
||||||
|
extra_msg = "(图片不存在)"
|
||||||
|
else:
|
||||||
|
extra_msg = "(无关联图片)"
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message=f"ID为 {face_id} 的人脸记录删除成功",
|
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
|
||||||
data=None
|
data=None
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
|
||||||
status_code=500,
|
|
||||||
detail=f"删除人脸记录失败: {str(e)}"
|
|
||||||
) from e
|
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
def get_all_face_name_with_eigenvalue() -> dict:
|
# ------------------------------
|
||||||
"""
|
# 6. 获取人脸图片
|
||||||
获取所有人脸的名称及其对应的特征值、组成字典返回
|
# ------------------------------
|
||||||
key: 人脸名称(name)
|
@router.get("/{face_id}/image", summary="获取人脸图片")
|
||||||
value: 人脸特征值(eigenvalue)、若名称重复则返回平均特征值
|
async def get_face_image(face_id: int):
|
||||||
注: 过滤掉name为None的记录、避免字典key为None的情况
|
conn = None
|
||||||
"""
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
query = "SELECT address, name FROM face WHERE id = %s"
|
||||||
|
cursor.execute(query, (face_id,))
|
||||||
|
face = cursor.fetchone()
|
||||||
|
|
||||||
|
if not face:
|
||||||
|
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||||
|
|
||||||
|
db_path = face["address"]
|
||||||
|
abs_path = Path(db_path).resolve() # 转为绝对路径
|
||||||
|
if not db_path or not abs_path.exists():
|
||||||
|
raise HTTPException(status_code=404, detail=f"图片不存在(路径:{db_path})")
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=abs_path,
|
||||||
|
filename=f"face_{face_id}_{face['name'] or '未命名'}.{db_path.split('.')[-1]}",
|
||||||
|
media_type=f"image/{db_path.split('.')[-1]}"
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取图片失败: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 内部工具方法
|
||||||
|
# ------------------------------
|
||||||
|
def get_all_face_name_with_eigenvalue() -> dict:
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
# 1. 建立数据库连接并获取游标(dictionary=True使结果以字典形式返回)
|
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 2. 执行SQL查询: 只获取name非空的记录、减少数据传输
|
|
||||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
faces = cursor.fetchall() # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...]
|
faces = cursor.fetchall()
|
||||||
|
|
||||||
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
|
|
||||||
name_to_eigenvalues = {}
|
name_to_eigenvalues = {}
|
||||||
for face in faces:
|
for face in faces:
|
||||||
name = face["name"]
|
name = face["name"]
|
||||||
eigenvalue = face["eigenvalue"]
|
eigenvalue = face["eigenvalue"]
|
||||||
# 若名称已存在、追加特征值;否则新建列表存储
|
|
||||||
if name in name_to_eigenvalues:
|
if name in name_to_eigenvalues:
|
||||||
name_to_eigenvalues[name].append(eigenvalue)
|
name_to_eigenvalues[name].append(eigenvalue)
|
||||||
else:
|
else:
|
||||||
name_to_eigenvalues[name] = [eigenvalue]
|
name_to_eigenvalues[name] = [eigenvalue]
|
||||||
|
|
||||||
# 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值
|
|
||||||
face_dict = {}
|
face_dict = {}
|
||||||
for name, eigenvalues in name_to_eigenvalues.items():
|
for name, eigenvalues in name_to_eigenvalues.items():
|
||||||
|
|
||||||
# 处理特征值: 多个则求平均、单个则直接使用
|
|
||||||
if len(eigenvalues) > 1:
|
if len(eigenvalues) > 1:
|
||||||
# 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入)
|
|
||||||
face_dict[name] = get_average_feature(eigenvalues)
|
face_dict[name] = get_average_feature(eigenvalues)
|
||||||
else:
|
else:
|
||||||
# 取列表中唯一的特征值(避免value为列表类型)
|
|
||||||
face_dict[name] = eigenvalues[0]
|
face_dict[name] = eigenvalues[0]
|
||||||
|
|
||||||
return face_dict
|
return face_dict
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
# 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题)
|
raise Exception(f"获取人脸特征失败: {str(e)}") from e
|
||||||
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
|
|
||||||
finally:
|
finally:
|
||||||
# 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏)
|
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
497
service/model_service.py
Normal file
497
service/model_service.py
Normal file
@ -0,0 +1,497 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from mysql.connector import Error as MySQLError
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# 复用项目依赖
|
||||||
|
from ds.db import db
|
||||||
|
from schema.model_schema import (
|
||||||
|
ModelCreateRequest,
|
||||||
|
ModelUpdateRequest,
|
||||||
|
ModelResponse,
|
||||||
|
ModelListResponse
|
||||||
|
)
|
||||||
|
from schema.response_schema import APIResponse
|
||||||
|
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
|
||||||
|
|
||||||
|
# 路径配置
|
||||||
|
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||||
|
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent
|
||||||
|
MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models"
|
||||||
|
MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True)
|
||||||
|
DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
|
||||||
|
|
||||||
|
# 模型限制
|
||||||
|
ALLOWED_MODEL_EXT = {"pt"}
|
||||||
|
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||||
|
|
||||||
|
# 全局模型变量
|
||||||
|
global _yolo_model
|
||||||
|
_yolo_model = None
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||||
|
|
||||||
|
|
||||||
|
# 工具函数:验证模型路径
|
||||||
|
def get_valid_model_abs_path(relative_path: str) -> str:
|
||||||
|
try:
|
||||||
|
relative_path = relative_path.replace("/", os.sep)
|
||||||
|
model_abs_path = PROJECT_ROOT / relative_path
|
||||||
|
model_abs_path = model_abs_path.resolve()
|
||||||
|
model_abs_path_str = str(model_abs_path)
|
||||||
|
|
||||||
|
if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_abs_path.exists():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"模型文件不存在!路径:{model_abs_path_str}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model_abs_path.is_file():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"路径不是文件!路径:{model_abs_path_str}"
|
||||||
|
)
|
||||||
|
|
||||||
|
file_size = model_abs_path.stat().st_size
|
||||||
|
if file_size > MAX_MODEL_SIZE:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"模型文件过大({file_size // 1024 // 1024}MB),超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
file_ext = model_abs_path.suffix.lower()
|
||||||
|
if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
|
||||||
|
return model_abs_path_str
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"路径处理失败:{str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
# 1. 上传模型
|
||||||
|
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||||
|
async def upload_model(
|
||||||
|
name: str = Form(..., description="模型名称"),
|
||||||
|
description: str = Form(None, description="模型描述"),
|
||||||
|
is_default: bool = Form(False, description="是否设为默认模型"),
|
||||||
|
file: UploadFile = File(..., description=f"YOLO模型文件(.pt,最大{MAX_MODEL_SIZE // 1024 // 1024}MB)")
|
||||||
|
):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
saved_file_path = None
|
||||||
|
try:
|
||||||
|
# 校验文件
|
||||||
|
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||||
|
if file_ext not in ALLOWED_MODEL_EXT:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
|
||||||
|
)
|
||||||
|
if file.size > MAX_MODEL_SIZE:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB,当前{file.size // 1024 // 1024}MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}"
|
||||||
|
saved_file_path = MODEL_SAVE_ROOT / safe_filename
|
||||||
|
with open(saved_file_path, "wb") as f:
|
||||||
|
shutil.copyfileobj(file.file, f)
|
||||||
|
saved_file_path.chmod(0o644) # 设置权限
|
||||||
|
|
||||||
|
# 数据库路径处理
|
||||||
|
db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/")
|
||||||
|
|
||||||
|
# 数据库操作
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
if is_default:
|
||||||
|
cursor.execute("UPDATE model SET is_default = 0")
|
||||||
|
|
||||||
|
insert_sql = """
|
||||||
|
INSERT INTO model (name, path, is_default, description, file_size)
|
||||||
|
VALUES (%s, %s, %s, %s, %s)
|
||||||
|
"""
|
||||||
|
cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
|
||||||
|
new_model = cursor.fetchone()
|
||||||
|
if not new_model:
|
||||||
|
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||||||
|
|
||||||
|
# 加载默认模型
|
||||||
|
global _yolo_model
|
||||||
|
if is_default:
|
||||||
|
valid_abs_path = get_valid_model_abs_path(db_relative_path)
|
||||||
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
|
if not _yolo_model:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=201,
|
||||||
|
message=f"模型上传成功!ID:{new_model['id']}",
|
||||||
|
data=ModelResponse(**new_model)
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
if saved_file_path and saved_file_path.exists():
|
||||||
|
saved_file_path.unlink()
|
||||||
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
|
except Exception as e:
|
||||||
|
if saved_file_path and saved_file_path.exists():
|
||||||
|
saved_file_path.unlink()
|
||||||
|
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
await file.close()
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 2. 获取模型列表
|
||||||
|
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||||||
|
async def get_model_list(
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
page_size: int = Query(10, ge=1, le=100),
|
||||||
|
name: str = Query(None),
|
||||||
|
is_default: bool = Query(None)
|
||||||
|
):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
where_clause = []
|
||||||
|
params = []
|
||||||
|
if name:
|
||||||
|
where_clause.append("name LIKE %s")
|
||||||
|
params.append(f"%{name}%")
|
||||||
|
if is_default is not None:
|
||||||
|
where_clause.append("is_default = %s")
|
||||||
|
params.append(1 if is_default else 0)
|
||||||
|
|
||||||
|
# 总记录数
|
||||||
|
count_sql = "SELECT COUNT(*) AS total FROM model"
|
||||||
|
if where_clause:
|
||||||
|
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||||
|
cursor.execute(count_sql, params)
|
||||||
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
|
# 分页数据
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
list_sql = "SELECT * FROM model"
|
||||||
|
if where_clause:
|
||||||
|
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||||
|
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
|
||||||
|
cursor.execute(list_sql, params)
|
||||||
|
model_list = cursor.fetchall()
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message=f"获取成功!共{total}条记录",
|
||||||
|
data=ModelListResponse(
|
||||||
|
total=total,
|
||||||
|
models=[ModelResponse(**model) for model in model_list]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 3. 获取默认模型
|
||||||
|
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
|
||||||
|
async def get_default_model():
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM model WHERE is_default = 1")
|
||||||
|
default_model = cursor.fetchone()
|
||||||
|
|
||||||
|
if not default_model:
|
||||||
|
raise HTTPException(status_code=404, detail="暂无默认模型")
|
||||||
|
|
||||||
|
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||||
|
global _yolo_model
|
||||||
|
|
||||||
|
if not _yolo_model:
|
||||||
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
|
if not _yolo_model:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="默认模型查询成功",
|
||||||
|
data=ModelResponse(**default_model)
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 4. 获取单个模型详情
|
||||||
|
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
|
||||||
|
async def get_model(model_id: int):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
|
model = cursor.fetchone()
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_abs_path = get_valid_model_abs_path(model["path"])
|
||||||
|
except HTTPException as e:
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message=f"查询成功,但路径异常:{e.detail}",
|
||||||
|
data=ModelResponse(**model)
|
||||||
|
)
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="查询成功",
|
||||||
|
data=ModelResponse(**model)
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 5. 更新模型信息
|
||||||
|
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
|
||||||
|
async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
|
exist_model = cursor.fetchone()
|
||||||
|
if not exist_model:
|
||||||
|
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||||
|
|
||||||
|
update_fields = []
|
||||||
|
params = []
|
||||||
|
if model_update.name is not None:
|
||||||
|
update_fields.append("name = %s")
|
||||||
|
params.append(model_update.name)
|
||||||
|
if model_update.description is not None:
|
||||||
|
update_fields.append("description = %s")
|
||||||
|
params.append(model_update.description)
|
||||||
|
|
||||||
|
need_load_default = False
|
||||||
|
if model_update.is_default is not None:
|
||||||
|
if model_update.is_default:
|
||||||
|
cursor.execute("UPDATE model SET is_default = 0")
|
||||||
|
update_fields.append("is_default = 1")
|
||||||
|
need_load_default = True
|
||||||
|
else:
|
||||||
|
cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
|
||||||
|
default_count = cursor.fetchone()["cnt"]
|
||||||
|
if default_count == 1 and exist_model["is_default"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="当前是唯一默认模型,不可取消!"
|
||||||
|
)
|
||||||
|
update_fields.append("is_default = 0")
|
||||||
|
|
||||||
|
if not update_fields:
|
||||||
|
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
||||||
|
|
||||||
|
params.append(model_id)
|
||||||
|
update_sql = f"""
|
||||||
|
UPDATE model
|
||||||
|
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
||||||
|
WHERE id = %s
|
||||||
|
"""
|
||||||
|
cursor.execute(update_sql, params)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
|
updated_model = cursor.fetchone()
|
||||||
|
|
||||||
|
global _yolo_model
|
||||||
|
if need_load_default:
|
||||||
|
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
|
||||||
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
|
if not _yolo_model:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="模型更新成功",
|
||||||
|
data=ModelResponse(**updated_model)
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 6. 删除模型
|
||||||
|
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||||
|
async def delete_model(model_id: int):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
|
exist_model = cursor.fetchone()
|
||||||
|
if not exist_model:
|
||||||
|
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||||
|
if exist_model["is_default"]:
|
||||||
|
raise HTTPException(status_code=400, detail="默认模型不可删除!")
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_abs_path_str = get_valid_model_abs_path(exist_model["path"])
|
||||||
|
model_abs_path = Path(model_abs_path_str)
|
||||||
|
except HTTPException as e:
|
||||||
|
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||||
|
conn.commit()
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message=f"记录删除成功,文件异常:{e.detail}",
|
||||||
|
data=None
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
extra_msg = ""
|
||||||
|
try:
|
||||||
|
model_abs_path.unlink()
|
||||||
|
extra_msg = f"(已删除文件)"
|
||||||
|
except Exception as e:
|
||||||
|
extra_msg = f"(文件删除失败:{str(e)})"
|
||||||
|
|
||||||
|
global _yolo_model
|
||||||
|
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str:
|
||||||
|
_yolo_model = None
|
||||||
|
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})")
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message=f"模型删除成功!ID:{model_id} {extra_msg}",
|
||||||
|
data=None
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 7. 下载模型文件
|
||||||
|
@router.get("/{model_id}/download", summary="下载模型文件")
|
||||||
|
async def download_model(model_id: int):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
|
model = cursor.fetchone()
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||||
|
|
||||||
|
valid_abs_path = get_valid_model_abs_path(model["path"])
|
||||||
|
model_abs_path = Path(valid_abs_path)
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=model_abs_path,
|
||||||
|
filename=f"model_{model_id}_{model['name']}.pt",
|
||||||
|
media_type="application/octet-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 对外提供当前模型
|
||||||
|
def get_current_yolo_model():
|
||||||
|
"""供检测模块获取当前加载的模型"""
|
||||||
|
global _yolo_model
|
||||||
|
if not _yolo_model:
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute("SELECT path FROM model WHERE is_default = 1")
|
||||||
|
default_model = cursor.fetchone()
|
||||||
|
if not default_model:
|
||||||
|
print("[get_current_yolo_model] 暂无默认模型")
|
||||||
|
return None
|
||||||
|
|
||||||
|
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||||
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
|
if _yolo_model:
|
||||||
|
print(f"[get_current_yolo_model] 自动加载默认模型成功")
|
||||||
|
else:
|
||||||
|
print(f"[get_current_yolo_model] 自动加载默认模型失败")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
return _yolo_model
|
@ -1,6 +1,7 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from mysql.connector import Error as MySQLError
|
from mysql.connector import Error as MySQLError
|
||||||
|
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
@ -11,7 +12,7 @@ from middle.auth_middleware import (
|
|||||||
verify_password,
|
verify_password,
|
||||||
create_access_token,
|
create_access_token,
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
get_current_user
|
get_current_user # 仅保留登录用户校验,移除is_admin导入
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
|
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
|
||||||
@ -152,3 +153,98 @@ async def get_current_user_info(
|
|||||||
data=current_user
|
data=current_user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 4. 获取用户列表(仅需登录权限)
|
||||||
|
# ------------------------------
|
||||||
|
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
|
||||||
|
async def get_user_list(
|
||||||
|
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||||
|
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||||
|
username: Optional[str] = Query(None, description="用户名模糊搜索"),
|
||||||
|
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取用户列表:
|
||||||
|
- 需登录权限(请求头携带 Token: Bearer <token>)
|
||||||
|
- 支持分页查询(page=页码,page_size=每页条数)
|
||||||
|
- 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等)
|
||||||
|
- 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息)
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 计算分页偏移量(page从1开始,偏移量=(页码-1)*每页条数)
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 基础查询(仅查非敏感字段)
|
||||||
|
base_query = """
|
||||||
|
SELECT id, username, created_at, updated_at
|
||||||
|
FROM users
|
||||||
|
"""
|
||||||
|
# 总条数查询(用于分页计算)
|
||||||
|
count_query = "SELECT COUNT(*) as total FROM users"
|
||||||
|
|
||||||
|
# 条件拼接(支持用户名模糊搜索)
|
||||||
|
conditions = []
|
||||||
|
params = []
|
||||||
|
if username:
|
||||||
|
conditions.append("username LIKE %s")
|
||||||
|
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
|
||||||
|
|
||||||
|
# 构建最终查询语句
|
||||||
|
if conditions:
|
||||||
|
where_clause = " WHERE " + " AND ".join(conditions)
|
||||||
|
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
|
||||||
|
final_count_query = f"{count_query}{where_clause}"
|
||||||
|
params.extend([page_size, offset]) # 追加分页参数
|
||||||
|
else:
|
||||||
|
final_query = f"{base_query} LIMIT %s OFFSET %s"
|
||||||
|
final_count_query = count_query
|
||||||
|
params = [page_size, offset]
|
||||||
|
|
||||||
|
# 1. 查询用户列表数据
|
||||||
|
cursor.execute(final_query, params)
|
||||||
|
users = cursor.fetchall()
|
||||||
|
|
||||||
|
# 2. 查询总条数(用于计算总页数)
|
||||||
|
count_params = [f"%{username}%"] if username else []
|
||||||
|
cursor.execute(final_count_query, count_params)
|
||||||
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
|
# 3. 转换为UserResponse模型(确保字段匹配)
|
||||||
|
user_list = [
|
||||||
|
UserResponse(
|
||||||
|
id=user["id"],
|
||||||
|
username=user["username"],
|
||||||
|
created_at=user["created_at"],
|
||||||
|
updated_at=user["updated_at"]
|
||||||
|
)
|
||||||
|
for user in users
|
||||||
|
]
|
||||||
|
|
||||||
|
# 4. 计算总页数(向上取整,如11条数据每页10条=2页)
|
||||||
|
total_pages = (total + page_size - 1) // page_size
|
||||||
|
|
||||||
|
# 返回结果(包含列表和分页信息)
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="获取用户列表成功",
|
||||||
|
data={
|
||||||
|
"users": user_list,
|
||||||
|
"pagination": {
|
||||||
|
"page": page, # 当前页码
|
||||||
|
"page_size": page_size, # 每页条数
|
||||||
|
"total": total, # 总数据量
|
||||||
|
"total_pages": total_pages # 总页数
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except MySQLError as e:
|
||||||
|
raise Exception(f"获取用户列表失败: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
# 无论成功失败,都关闭数据库连接
|
||||||
|
db.close_connection(conn, cursor)
|
@ -4,6 +4,11 @@ import insightface
|
|||||||
from insightface.app import FaceAnalysis
|
from insightface.app import FaceAnalysis
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 配置日志(便于排查)
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - [FaceUtil] - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 全局变量存储InsightFace引擎和特征列表
|
# 全局变量存储InsightFace引擎和特征列表
|
||||||
_insightface_app = None
|
_insightface_app = None
|
||||||
@ -11,135 +16,141 @@ _feature_list = []
|
|||||||
|
|
||||||
|
|
||||||
def init_insightface():
|
def init_insightface():
|
||||||
"""初始化InsightFace引擎"""
|
"""初始化InsightFace引擎(确保成功后再使用)"""
|
||||||
global _insightface_app
|
global _insightface_app
|
||||||
try:
|
try:
|
||||||
print("正在初始化InsightFace引擎...")
|
if _insightface_app is not None:
|
||||||
app = FaceAnalysis(name='buffalo_l', root='~/.insightface')
|
logger.info("InsightFace引擎已初始化,无需重复执行")
|
||||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
return _insightface_app
|
||||||
print("InsightFace引擎初始化完成")
|
|
||||||
|
logger.info("正在初始化InsightFace引擎(模型:buffalo_l)...")
|
||||||
|
# 手动指定模型下载路径(避免权限问题,可选)
|
||||||
|
app = FaceAnalysis(
|
||||||
|
name='buffalo_l',
|
||||||
|
root='~/.insightface', # 模型默认下载路径
|
||||||
|
providers=['CPUExecutionProvider'] # 强制用CPU(若有GPU可加'CUDAExecutionProvider')
|
||||||
|
)
|
||||||
|
app.prepare(ctx_id=0, det_size=(640, 640)) # det_size越大,小人脸检测越准
|
||||||
|
logger.info("InsightFace引擎初始化完成")
|
||||||
_insightface_app = app
|
_insightface_app = app
|
||||||
return app
|
return app
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"InsightFace初始化失败: {e}")
|
logger.error(f"InsightFace初始化失败:{str(e)}", exc_info=True) # 打印详细堆栈
|
||||||
|
_insightface_app = None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def add_binary_data(binary_data):
|
def add_binary_data(binary_data):
|
||||||
"""
|
"""
|
||||||
接收单张图片的二进制数据、提取特征并保存
|
接收单张图片的二进制数据、提取特征并保存
|
||||||
|
返回:(True, 特征值numpy数组) 或 (False, 错误信息字符串)
|
||||||
参数:
|
|
||||||
binary_data: 图片的二进制数据(bytes类型)
|
|
||||||
|
|
||||||
返回:
|
|
||||||
成功提取特征时返回 (True, 特征值numpy数组)
|
|
||||||
失败时返回 (False, None)
|
|
||||||
"""
|
"""
|
||||||
global _insightface_app, _feature_list
|
global _insightface_app, _feature_list
|
||||||
|
|
||||||
|
# 1. 先检查引擎是否初始化成功
|
||||||
if not _insightface_app:
|
if not _insightface_app:
|
||||||
print("引擎未初始化、无法处理")
|
init_result = init_insightface() # 尝试重新初始化
|
||||||
return False, None
|
if not init_result:
|
||||||
|
error_msg = "InsightFace引擎未初始化,无法检测人脸"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 直接处理二进制数据: 转换为图像格式
|
# 2. 验证二进制数据有效性
|
||||||
img = Image.open(BytesIO(binary_data))
|
if len(binary_data) < 1024: # 过滤过小的无效图片(小于1KB)
|
||||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
error_msg = f"图片过小({len(binary_data)}字节),可能不是有效图片"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
# 提取特征
|
# 3. 二进制数据转CV2格式(关键步骤,避免通道错误)
|
||||||
faces = _insightface_app.get(frame)
|
try:
|
||||||
if faces:
|
img = Image.open(BytesIO(binary_data)).convert("RGB") # 强制转RGB
|
||||||
# 获取当前提取的特征值
|
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) # InsightFace需要BGR格式
|
||||||
current_feature = faces[0].embedding
|
|
||||||
# 添加到特征列表
|
|
||||||
_feature_list.append(current_feature)
|
|
||||||
print(f"已累计 {len(_feature_list)} 个特征")
|
|
||||||
# 返回成功标志和当前特征值
|
|
||||||
return True, current_feature
|
|
||||||
else:
|
|
||||||
print("二进制数据中未检测到人脸")
|
|
||||||
return False, None
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"处理二进制数据出错: {e}")
|
error_msg = f"图片格式转换失败:{str(e)}"
|
||||||
return False, None
|
logger.error(error_msg, exc_info=True)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
# 4. 检查图片尺寸(避免极端尺寸导致检测失败)
|
||||||
|
height, width = frame.shape[:2]
|
||||||
|
if height < 64 or width < 64: # 人脸检测最小建议尺寸
|
||||||
|
error_msg = f"图片尺寸过小({width}x{height}),需至少64x64像素"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
# 5. 调用InsightFace检测人脸
|
||||||
|
logger.info(f"开始检测人脸(图片尺寸:{width}x{height},格式:BGR)")
|
||||||
|
faces = _insightface_app.get(frame)
|
||||||
|
|
||||||
|
if not faces:
|
||||||
|
error_msg = "未检测到人脸(请确保图片包含清晰正面人脸,无遮挡、不模糊)"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
# 6. 提取特征并保存
|
||||||
|
current_feature = faces[0].embedding
|
||||||
|
_feature_list.append(current_feature)
|
||||||
|
logger.info(f"人脸检测成功,提取特征值(维度:{current_feature.shape[0]}),累计特征数:{len(_feature_list)}")
|
||||||
|
return True, current_feature
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"处理图片时发生异常:{str(e)}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
|
||||||
|
# 以下函数保持不变(get_average_feature/clear_features/get_feature_list)
|
||||||
def get_average_feature(features=None):
|
def get_average_feature(features=None):
|
||||||
"""
|
|
||||||
计算多个特征向量的平均值
|
|
||||||
|
|
||||||
参数:
|
|
||||||
features: 可选、特征值列表。如果未提供、则使用全局存储的_feature_list
|
|
||||||
每个元素可以是字符串格式或numpy数组
|
|
||||||
|
|
||||||
返回:
|
|
||||||
单一平均特征向量的numpy数组、若无可计算数据则返回None
|
|
||||||
"""
|
|
||||||
global _feature_list
|
global _feature_list
|
||||||
|
try:
|
||||||
# 如果未提供features参数、则使用全局特征列表
|
|
||||||
if features is None:
|
if features is None:
|
||||||
features = _feature_list
|
features = _feature_list
|
||||||
|
|
||||||
try:
|
|
||||||
# 验证输入是否为列表且不为空
|
|
||||||
if not isinstance(features, list) or len(features) == 0:
|
if not isinstance(features, list) or len(features) == 0:
|
||||||
print("输入必须是包含至少一个特征值的列表")
|
logger.warning("输入必须是包含至少一个特征值的列表")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 处理每个特征值
|
|
||||||
processed_features = []
|
processed_features = []
|
||||||
for i, embedding in enumerate(features):
|
for i, embedding in enumerate(features):
|
||||||
try:
|
try:
|
||||||
if isinstance(embedding, str):
|
if isinstance(embedding, str):
|
||||||
# 处理包含括号和逗号的字符串格式
|
|
||||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||||
else:
|
else:
|
||||||
embedding_np = np.array(embedding, dtype=np.float32)
|
embedding_np = np.array(embedding, dtype=np.float32)
|
||||||
|
|
||||||
# 验证特征值格式
|
|
||||||
if len(embedding_np.shape) == 1:
|
if len(embedding_np.shape) == 1:
|
||||||
processed_features.append(embedding_np)
|
processed_features.append(embedding_np)
|
||||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
logger.info(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||||
else:
|
else:
|
||||||
print(f"跳过第 {i + 1} 个特征值、不是一维数组")
|
logger.warning(f"跳过第 {i + 1} 个特征值:不是一维数组")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"处理第 {i + 1} 个特征值时出错: {e}")
|
logger.error(f"处理第 {i + 1} 个特征值时出错:{str(e)}")
|
||||||
|
|
||||||
# 确保有有效的特征值
|
|
||||||
if not processed_features:
|
if not processed_features:
|
||||||
print("没有有效的特征值用于计算平均值")
|
logger.warning("没有有效的特征值用于计算平均值")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 检查所有特征向量维度是否相同
|
|
||||||
dims = {feat.shape[0] for feat in processed_features}
|
dims = {feat.shape[0] for feat in processed_features}
|
||||||
if len(dims) > 1:
|
if len(dims) > 1:
|
||||||
print(f"特征值维度不一致、无法计算平均值。检测到的维度: {dims}")
|
logger.error(f"特征值维度不一致:{dims},无法计算平均值")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 计算平均值
|
|
||||||
avg_feature = np.mean(processed_features, axis=0)
|
avg_feature = np.mean(processed_features, axis=0)
|
||||||
print(f"成功计算 {len(processed_features)} 个特征值的平均特征向量、维度: {avg_feature.shape[0]}")
|
logger.info(f"计算成功:{len(processed_features)} 个特征值的平均向量(维度:{avg_feature.shape[0]})")
|
||||||
|
|
||||||
return avg_feature
|
return avg_feature
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"计算平均特征值时出错: {e}")
|
logger.error(f"计算平均特征值出错:{str(e)}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def clear_features():
|
def clear_features():
|
||||||
"""清空已存储的特征数据"""
|
|
||||||
global _feature_list
|
global _feature_list
|
||||||
_feature_list = []
|
_feature_list = []
|
||||||
print("已清空所有特征数据")
|
logger.info("已清空所有特征数据")
|
||||||
|
|
||||||
|
|
||||||
def get_feature_list():
|
def get_feature_list():
|
||||||
"""获取当前存储的特征列表"""
|
|
||||||
global _feature_list
|
global _feature_list
|
||||||
return _feature_list.copy() # 返回副本防止外部直接修改
|
logger.info(f"当前特征列表长度:{len(_feature_list)}")
|
||||||
|
return _feature_list.copy()
|
83
util/file_util.py
Normal file
83
util/file_util.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
def save_face_to_up_images(
|
||||||
|
client_ip: str,
|
||||||
|
face_name: str,
|
||||||
|
image_bytes: bytes,
|
||||||
|
image_format: str = "jpg"
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
|
||||||
|
修复路径计算错误,确保所有路径在up_images根目录下
|
||||||
|
|
||||||
|
参数:
|
||||||
|
client_ip: 客户端IP(原始格式,如192.168.1.101)
|
||||||
|
face_name: 人脸名称(用户输入,可为空)
|
||||||
|
image_bytes: 人脸图片二进制数据
|
||||||
|
image_format: 图片格式(默认jpg)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
字典:success(是否成功)、db_path(存数据库的相对路径)、local_abs_path(本地绝对路径)、msg(提示)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 1. 基础参数校验
|
||||||
|
if not client_ip.strip():
|
||||||
|
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "客户端IP不能为空"}
|
||||||
|
if not image_bytes:
|
||||||
|
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "图片二进制数据为空"}
|
||||||
|
if image_format.lower() not in ["jpg", "jpeg", "png"]:
|
||||||
|
return {"success": False, "db_path": "", "local_abs_path": "", "msg": "仅支持jpg/jpeg/png格式"}
|
||||||
|
|
||||||
|
# 2. 处理特殊字符(避免路径错误)
|
||||||
|
safe_ip = client_ip.strip().replace(".", "_") # IP中的.替换为_
|
||||||
|
safe_face_name = face_name.strip() if (face_name and face_name.strip()) else "未命名"
|
||||||
|
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
|
||||||
|
|
||||||
|
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
|
||||||
|
root_dir = Path("up_images").resolve() # 转为绝对路径(关键修复!)
|
||||||
|
if not root_dir.exists():
|
||||||
|
root_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"[FileUtil] 已创建up_images根目录:{root_dir}")
|
||||||
|
|
||||||
|
# 4. 构建文件层级路径(确保在root_dir子目录下)
|
||||||
|
ip_dir = root_dir / safe_ip
|
||||||
|
face_name_dir = ip_dir / safe_face_name
|
||||||
|
face_name_dir.mkdir(parents=True, exist_ok=True) # 自动创建目录
|
||||||
|
print(f"[FileUtil] 图片存储目录:{face_name_dir}")
|
||||||
|
|
||||||
|
# 5. 生成唯一文件名(毫秒级时间戳)
|
||||||
|
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
|
||||||
|
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
|
||||||
|
|
||||||
|
# 6. 计算路径(关键修复:确保所有路径都是绝对路径且在root_dir下)
|
||||||
|
local_abs_path = face_name_dir / image_filename # 绝对路径
|
||||||
|
|
||||||
|
# 验证路径是否在root_dir下(防止路径穿越攻击)
|
||||||
|
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
|
||||||
|
raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}")
|
||||||
|
|
||||||
|
# 数据库存储路径:从root_dir开始的相对路径(如 up_images/192_168_110_31/小王/xxx.jpg)
|
||||||
|
db_path = str(root_dir.name / local_abs_path.relative_to(root_dir))
|
||||||
|
|
||||||
|
# 7. 写入图片文件
|
||||||
|
with open(local_abs_path, "wb") as f:
|
||||||
|
f.write(image_bytes)
|
||||||
|
print(f"[FileUtil] 图片保存成功:")
|
||||||
|
print(f" 数据库路径:{db_path}")
|
||||||
|
print(f" 本地绝对路径:{local_abs_path}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"db_path": db_path, # 存数据库的相对路径(up_images开头)
|
||||||
|
"local_abs_path": str(local_abs_path), # 本地绝对路径
|
||||||
|
"msg": "图片保存成功"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"图片保存失败:{str(e)}"
|
||||||
|
print(f"[FileUtil] 错误:{error_msg}")
|
||||||
|
return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}
|
61
util/model_util.py
Normal file
61
util/model_util.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import traceback
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def load_yolo_model(model_path: str) -> Optional[YOLO]:
|
||||||
|
"""
|
||||||
|
加载YOLO模型(支持v5/v8),并校验模型有效性
|
||||||
|
:param model_path: 模型文件的绝对路径
|
||||||
|
:return: 加载成功返回YOLO模型实例,失败返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 加载前的基础信息检查
|
||||||
|
print(f"\n[模型工具] 开始加载模型:{model_path}")
|
||||||
|
print(f"[模型工具] 文件是否存在:{os.path.exists(model_path)}")
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
print(f"[模型工具] 文件大小:{os.path.getsize(model_path) / 1024 / 1024:.2f} MB")
|
||||||
|
|
||||||
|
# 强制重新加载模型,避免缓存问题
|
||||||
|
model = YOLO(model_path)
|
||||||
|
|
||||||
|
# 兼容性校验:使用numpy空数组测试模型
|
||||||
|
dummy_image = np.zeros((640, 640, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 优先使用新版本参数
|
||||||
|
model.predict(
|
||||||
|
source=dummy_image,
|
||||||
|
imgsz=640,
|
||||||
|
conf=0.25,
|
||||||
|
verbose=False,
|
||||||
|
stream=False
|
||||||
|
)
|
||||||
|
except Exception as pred_e:
|
||||||
|
print(f"[模型工具] 预测校验兼容处理:{str(pred_e)}")
|
||||||
|
# 兼容旧版本YOLO参数
|
||||||
|
model.predict(
|
||||||
|
img=dummy_image,
|
||||||
|
imgsz=640,
|
||||||
|
conf=0.25,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证模型基本属性
|
||||||
|
if not hasattr(model, 'names'):
|
||||||
|
print("[模型工具] 警告:模型缺少类别名称属性")
|
||||||
|
else:
|
||||||
|
print(f"[模型工具] 模型包含类别:{list(model.names.values())[:5]}...") # 显示前5个类别
|
||||||
|
|
||||||
|
print(f"[模型工具] 模型加载成功!")
|
||||||
|
return model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 详细错误信息输出
|
||||||
|
print(f"\n[模型工具] 加载模型失败!路径:{model_path}")
|
||||||
|
print(f"[模型工具] 异常类型:{type(e).__name__}")
|
||||||
|
print(f"[模型工具] 异常详情:{str(e)}")
|
||||||
|
print(f"[模型工具] 堆栈跟踪:\n{traceback.format_exc()}")
|
||||||
|
return None
|
Reference in New Issue
Block a user