目前可以成功动态更换模型运行的

This commit is contained in:
2025-09-12 14:05:09 +08:00
parent 435b2a0e6c
commit 4be7f7bf14
13 changed files with 1518 additions and 325 deletions

View File

@ -1,162 +1,140 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
import os
from pathlib import Path
from ds.db import db
from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse
from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse
from util.face_util import add_binary_data,get_average_feature
#初始化实例
router = APIRouter(
prefix="/faces",
tags=["人脸管理"]
from schema.face_schema import (
FaceCreateRequest,
FaceUpdateRequest,
FaceResponse,
FaceListResponse
)
from schema.response_schema import APIResponse
from util.face_util import add_binary_data, get_average_feature
from util.file_util import save_face_to_up_images
router = APIRouter(prefix="/faces", tags=["人脸管理"])
# ------------------------------
# 1. 创建人脸记录(核心修正: ID 数据库自增、前端无需传
# 1. 创建人脸记录(使用修复后的路径
# ------------------------------
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件、ID自增")
@router.post("", response_model=APIResponse, summary="创建人脸记录")
async def create_face(
# 前端仅需传: name可选、Form格式、file必传、文件
request: Request,
name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传、暂不处理内容")
file: UploadFile = File(..., description="人脸文件(必传)")
):
"""
创建人脸记录:
- 需登录认证
- 前端传参: multipart/form-data 表单name 可选、file 必传)
- ID 由数据库自动生成、无需前端传入
- 暂不处理文件内容、eigenvalue 设为 None
"""
# 调用你的方法
conn = None
cursor = None
try:
# 1. 用模型校验 name仅校验长度、无需ID
face_create = FaceCreateRequest(name=name)
client_ip = request.client.host if request.client else ""
if not client_ip:
raise HTTPException(status_code=400, detail="无法获取客户端IP")
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 把文件转为二进制数组
# 读取图片并保存(使用修复后的路径逻辑)
file_content = await file.read()
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "jpg"
save_result = save_face_to_up_images(
client_ip=client_ip,
face_name=name,
image_bytes=file_content,
image_format=file_ext
)
if not save_result["success"]:
raise HTTPException(status_code=500, detail=f"图片保存失败:{save_result['msg']}")
db_image_path = save_result["db_path"] # 从修复后的方法获取路径
# 计算特征
flag, eigenvalue = add_binary_data(file_content)
# 提取人脸特征
detect_success, detect_result = add_binary_data(file_content)
if not detect_success:
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
eigenvalue = detect_result
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 打印数组长度
print(f"文件大小: {len(file_content)} 字节")
# 2. 插入数据库: 无需传 ID自增、只传 name 和 eigenvalueNone
# 插入数据库
insert_query = """
INSERT INTO face (name, eigenvalue)
VALUES (%s, %s)
INSERT INTO face (name, eigenvalue, address)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
cursor.execute(insert_query, (face_create.name, str(eigenvalue), db_image_path))
conn.commit()
# 3. 获取数据库自动生成的 ID关键: 用 LAST_INSERT_ID() 查刚插入的记录
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
cursor.execute(select_new_query)
# 查询新记录
cursor.execute("""
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = LAST_INSERT_ID()
""")
created_face = cursor.fetchone()
if not created_face:
raise HTTPException(
status_code=500,
detail="创建人脸记录成功、但无法获取新创建的记录"
)
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
return APIResponse(
code=201,
message=f"人脸记录创建成功ID: {created_face['id']}、文件名: {file.filename}",
data=FaceResponse(** created_face)
message=f"人脸记录创建成功ID: {created_face['id']}",
data=FaceResponse(**created_face)
)
except MySQLError as e:
if conn:
conn.rollback()
# 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"创建人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
except Exception as e:
# 捕获其他可能的异常
raise HTTPException(
status_code=500,
detail=f"服务器错误: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
finally:
await file.close() # 关闭文件流
await file.close()
db.close_connection(conn, cursor)
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 将 eigenvalue 转为 str
eigenvalue = str(eigenvalue)
# 其他接口(获取单条/列表、更新、删除、获取图片)与之前一致,无需修改
# ------------------------------
# 2. 获取单个人脸记录不变、用自增ID查询
# 2. 获取单个人脸记录
# ------------------------------
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
async def get_face(
face_id: int, # 这里的 ID 是数据库自增的、前端从创建响应中获取
current_user: UserResponse = Depends(get_current_user)
):
async def get_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM face WHERE id = %s"
query = """
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
"""
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
return APIResponse(
code=200,
message="人脸记录查询成功",
message="查询成功",
data=FaceResponse(**face)
)
except MySQLError as e:
# 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"查询人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理)
# ------------------------------
# 3. 获取所有人脸记录(不变)
# 3. 获取人脸列表
# ------------------------------
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
async def get_all_faces(
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
async def get_face_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
has_eigenvalue: bool = Query(None)
):
conn = None
cursor = None
@ -164,50 +142,66 @@ async def get_all_faces(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM face ORDER BY id" # 按自增ID排序
cursor.execute(query)
faces = cursor.fetchall()
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if has_eigenvalue is not None:
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
# 总记录数
count_query = "SELECT COUNT(*) AS total FROM face"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 列表数据
offset = (page - 1) * page_size
list_query = """
SELECT id, name, address, created_at, updated_at
FROM face
"""
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_query, params)
face_list = cursor.fetchall()
return APIResponse(
code=200,
message="所有人脸记录查询成功",
data=[FaceResponse(** face) for face in faces]
message=f"获取成功(共{total}条)",
data=FaceListResponse(
total=total,
faces=[FaceResponse(**face) for face in face_list]
)
)
except MySQLError as e:
raise HTTPException(
status_code=500,
detail=f"查询所有人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 4. 更新人脸记录不变、用自增ID更新
# 4. 更新人脸记录
# ------------------------------
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
async def update_face(
face_id: int,
face_update: FaceUpdateRequest,
current_user: UserResponse = Depends(get_current_user)
):
async def update_face(face_id: int, face_update: FaceUpdateRequest):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查记录是否存在
check_query = "SELECT id FROM face WHERE id = %s"
cursor.execute(check_query, (face_id,))
existing_face = cursor.fetchone()
if not existing_face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
exist_face = cursor.fetchone()
if not exist_face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
old_db_path = exist_face["address"]
# 构建更新语句
update_fields = []
params = []
if face_update.name is not None:
@ -216,6 +210,18 @@ async def update_face(
if face_update.eigenvalue is not None:
update_fields.append("eigenvalue = %s")
params.append(face_update.eigenvalue)
if face_update.address is not None:
# 删除旧图片(相对路径转绝对路径)
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink() # 使用Path方法删除更安全
print(f"[FaceRouter] 已删除旧图片:{old_abs_path}")
except Exception as e:
print(f"[FaceRouter] 删除旧图片失败:{str(e)}")
update_fields.append("address = %s")
params.append(face_update.address)
if not update_fields:
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
@ -225,117 +231,143 @@ async def update_face(
cursor.execute(update_query, params)
conn.commit()
# 查询更新后记录
select_query = "SELECT * FROM face WHERE id = %s"
cursor.execute(select_query, (face_id,))
cursor.execute("""
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
""", (face_id,))
updated_face = cursor.fetchone()
return APIResponse(
code=200,
message="人脸记录更新成功",
message="更新成功",
data=FaceResponse(**updated_face)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"更新人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"更新失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 5. 删除人脸记录不变、用自增ID删除
# 5. 删除人脸记录
# ------------------------------
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
async def delete_face(
face_id: int,
current_user: UserResponse = Depends(get_current_user)
):
async def delete_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
check_query = "SELECT id FROM face WHERE id = %s"
cursor.execute(check_query, (face_id,))
existing_face = cursor.fetchone()
if not existing_face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
exist_face = cursor.fetchone()
if not exist_face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
old_db_path = exist_face["address"]
delete_query = "DELETE FROM face WHERE id = %s"
cursor.execute(delete_query, (face_id,))
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
conn.commit()
# 删除图片
extra_msg = ""
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink()
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
extra_msg = "(已同步删除图片)"
except Exception as e:
print(f"[FaceRouter] 删除图片失败:{str(e)}")
extra_msg = "(图片删除失败)"
else:
extra_msg = "(图片不存在)"
else:
extra_msg = "(无关联图片)"
return APIResponse(
code=200,
message=f"ID为 {face_id}人脸记录删除成功",
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"删除人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def get_all_face_name_with_eigenvalue() -> dict:
"""
获取所有人脸的名称及其对应的特征值、组成字典返回
key: 人脸名称name
value: 人脸特征值eigenvalue、若名称重复则返回平均特征值
注: 过滤掉name为None的记录、避免字典key为None的情况
"""
# ------------------------------
# 6. 获取人脸图片
# ------------------------------
@router.get("/{face_id}/image", summary="获取人脸图片")
async def get_face_image(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT address, name FROM face WHERE id = %s"
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
db_path = face["address"]
abs_path = Path(db_path).resolve() # 转为绝对路径
if not db_path or not abs_path.exists():
raise HTTPException(status_code=404, detail=f"图片不存在(路径:{db_path}")
return FileResponse(
path=abs_path,
filename=f"face_{face_id}_{face['name'] or '未命名'}.{db_path.split('.')[-1]}",
media_type=f"image/{db_path.split('.')[-1]}"
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"获取图片失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法
# ------------------------------
def get_all_face_name_with_eigenvalue() -> dict:
conn = None
cursor = None
try:
# 1. 建立数据库连接并获取游标dictionary=True使结果以字典形式返回
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 2. 执行SQL查询: 只获取name非空的记录、减少数据传输
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query)
faces = cursor.fetchall() # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...]
faces = cursor.fetchall()
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
name_to_eigenvalues = {}
for face in faces:
name = face["name"]
eigenvalue = face["eigenvalue"]
# 若名称已存在、追加特征值;否则新建列表存储
if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue)
else:
name_to_eigenvalues[name] = [eigenvalue]
# 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值
face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items():
# 处理特征值: 多个则求平均、单个则直接使用
if len(eigenvalues) > 1:
# 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入
face_dict[name] = get_average_feature(eigenvalues)
else:
# 取列表中唯一的特征值避免value为列表类型
face_dict[name] = eigenvalues[0]
return face_dict
except MySQLError as e:
# 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题)
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
raise Exception(f"获取人脸特征失败: {str(e)}") from e
finally:
# 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏)
db.close_connection(conn, cursor)

497
service/model_service.py Normal file
View File

@ -0,0 +1,497 @@
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
import os
import shutil
from pathlib import Path
from datetime import datetime
# 复用项目依赖
from ds.db import db
from schema.model_schema import (
ModelCreateRequest,
ModelUpdateRequest,
ModelResponse,
ModelListResponse
)
from schema.response_schema import APIResponse
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
# 路径配置
CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent
MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models"
MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True)
DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
# 模型限制
ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
# 全局模型变量
global _yolo_model
_yolo_model = None
router = APIRouter(prefix="/models", tags=["模型管理"])
# 工具函数:验证模型路径
def get_valid_model_abs_path(relative_path: str) -> str:
try:
relative_path = relative_path.replace("/", os.sep)
model_abs_path = PROJECT_ROOT / relative_path
model_abs_path = model_abs_path.resolve()
model_abs_path_str = str(model_abs_path)
if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)):
raise HTTPException(
status_code=400,
detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
)
if not model_abs_path.exists():
raise HTTPException(
status_code=404,
detail=f"模型文件不存在!路径:{model_abs_path_str}"
)
if not model_abs_path.is_file():
raise HTTPException(
status_code=400,
detail=f"路径不是文件!路径:{model_abs_path_str}"
)
file_size = model_abs_path.stat().st_size
if file_size > MAX_MODEL_SIZE:
raise HTTPException(
status_code=400,
detail=f"模型文件过大({file_size // 1024 // 1024}MB超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
)
file_ext = model_abs_path.suffix.lower()
if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]:
raise HTTPException(
status_code=400,
detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
)
print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
return model_abs_path_str
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"路径处理失败:{str(e)}"
) from e
# 1. 上传模型
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
async def upload_model(
name: str = Form(..., description="模型名称"),
description: str = Form(None, description="模型描述"),
is_default: bool = Form(False, description="是否设为默认模型"),
file: UploadFile = File(..., description=f"YOLO模型文件.pt最大{MAX_MODEL_SIZE // 1024 // 1024}MB")
):
conn = None
cursor = None
saved_file_path = None
try:
# 校验文件
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
if file_ext not in ALLOWED_MODEL_EXT:
raise HTTPException(
status_code=400,
detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
)
if file.size > MAX_MODEL_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB当前{file.size // 1024 // 1024}MB"
)
# 保存文件
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}"
saved_file_path = MODEL_SAVE_ROOT / safe_filename
with open(saved_file_path, "wb") as f:
shutil.copyfileobj(file.file, f)
saved_file_path.chmod(0o644) # 设置权限
# 数据库路径处理
db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/")
# 数据库操作
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
if is_default:
cursor.execute("UPDATE model SET is_default = 0")
insert_sql = """
INSERT INTO model (name, path, is_default, description, file_size)
VALUES (%s, %s, %s, %s, %s)
"""
cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size))
conn.commit()
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
new_model = cursor.fetchone()
if not new_model:
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
# 加载默认模型
global _yolo_model
if is_default:
valid_abs_path = get_valid_model_abs_path(db_relative_path)
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
raise HTTPException(
status_code=500,
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path}"
)
return APIResponse(
code=201,
message=f"模型上传成功ID{new_model['id']}",
data=ModelResponse(**new_model)
)
except MySQLError as e:
if conn:
conn.rollback()
if saved_file_path and saved_file_path.exists():
saved_file_path.unlink()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
except Exception as e:
if saved_file_path and saved_file_path.exists():
saved_file_path.unlink()
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
finally:
await file.close()
db.close_connection(conn, cursor)
# 2. 获取模型列表
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
async def get_model_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
is_default: bool = Query(None)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if is_default is not None:
where_clause.append("is_default = %s")
params.append(1 if is_default else 0)
# 总记录数
count_sql = "SELECT COUNT(*) AS total FROM model"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 分页数据
offset = (page - 1) * page_size
list_sql = "SELECT * FROM model"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_sql, params)
model_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取成功!共{total}条记录",
data=ModelListResponse(
total=total,
models=[ModelResponse(**model) for model in model_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 3. 获取默认模型
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
async def get_default_model():
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
raise HTTPException(status_code=404, detail="暂无默认模型")
valid_abs_path = get_valid_model_abs_path(default_model["path"])
global _yolo_model
if not _yolo_model:
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
raise HTTPException(
status_code=500,
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path}"
)
return APIResponse(
code=200,
message="默认模型查询成功",
data=ModelResponse(**default_model)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 4. 获取单个模型详情
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
async def get_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
model = cursor.fetchone()
if not model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
try:
model_abs_path = get_valid_model_abs_path(model["path"])
except HTTPException as e:
return APIResponse(
code=200,
message=f"查询成功,但路径异常:{e.detail}",
data=ModelResponse(**model)
)
return APIResponse(
code=200,
message="查询成功",
data=ModelResponse(**model)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 5. 更新模型信息
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
async def update_model(model_id: int, model_update: ModelUpdateRequest):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
exist_model = cursor.fetchone()
if not exist_model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
update_fields = []
params = []
if model_update.name is not None:
update_fields.append("name = %s")
params.append(model_update.name)
if model_update.description is not None:
update_fields.append("description = %s")
params.append(model_update.description)
need_load_default = False
if model_update.is_default is not None:
if model_update.is_default:
cursor.execute("UPDATE model SET is_default = 0")
update_fields.append("is_default = 1")
need_load_default = True
else:
cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
default_count = cursor.fetchone()["cnt"]
if default_count == 1 and exist_model["is_default"]:
raise HTTPException(
status_code=400,
detail="当前是唯一默认模型,不可取消!"
)
update_fields.append("is_default = 0")
if not update_fields:
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
params.append(model_id)
update_sql = f"""
UPDATE model
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
WHERE id = %s
"""
cursor.execute(update_sql, params)
conn.commit()
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone()
global _yolo_model
if need_load_default:
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
raise HTTPException(
status_code=500,
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path}"
)
return APIResponse(
code=200,
message="模型更新成功",
data=ModelResponse(**updated_model)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 6. 删除模型
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
async def delete_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
exist_model = cursor.fetchone()
if not exist_model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
if exist_model["is_default"]:
raise HTTPException(status_code=400, detail="默认模型不可删除!")
try:
model_abs_path_str = get_valid_model_abs_path(exist_model["path"])
model_abs_path = Path(model_abs_path_str)
except HTTPException as e:
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
conn.commit()
return APIResponse(
code=200,
message=f"记录删除成功,文件异常:{e.detail}",
data=None
)
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
conn.commit()
extra_msg = ""
try:
model_abs_path.unlink()
extra_msg = f"(已删除文件)"
except Exception as e:
extra_msg = f"(文件删除失败:{str(e)}"
global _yolo_model
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str:
_yolo_model = None
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str}")
return APIResponse(
code=200,
message=f"模型删除成功ID{model_id} {extra_msg}",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 7. 下载模型文件
@router.get("/{model_id}/download", summary="下载模型文件")
async def download_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
model = cursor.fetchone()
if not model:
raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
valid_abs_path = get_valid_model_abs_path(model["path"])
model_abs_path = Path(valid_abs_path)
return FileResponse(
path=model_abs_path,
filename=f"model_{model_id}_{model['name']}.pt",
media_type="application/octet-stream"
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 对外提供当前模型
def get_current_yolo_model():
"""供检测模块获取当前加载的模型"""
global _yolo_model
if not _yolo_model:
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT path FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
print("[get_current_yolo_model] 暂无默认模型")
return None
valid_abs_path = get_valid_model_abs_path(default_model["path"])
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
print(f"[get_current_yolo_model] 自动加载默认模型成功")
else:
print(f"[get_current_yolo_model] 自动加载默认模型失败")
except Exception as e:
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
finally:
db.close_connection(conn, cursor)
return _yolo_model

View File

@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from mysql.connector import Error as MySQLError
from ds.db import db
@ -11,7 +12,7 @@ from middle.auth_middleware import (
verify_password,
create_access_token,
ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_user
get_current_user # 仅保留登录用户校验移除is_admin导入
)
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
@ -27,7 +28,7 @@ router = APIRouter(
@router.post("/register", response_model=APIResponse, summary="用户注册")
async def user_register(request: UserRegisterRequest):
"""
用户注册:
用户注册:
- 校验用户名是否已存在
- 加密密码后插入数据库
- 返回注册成功信息
@ -78,7 +79,7 @@ async def user_register(request: UserRegisterRequest):
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token")
async def user_login(request: UserLoginRequest):
"""
用户登录:
用户登录:
- 校验用户名是否存在
- 校验密码是否正确
- 生成 JWT Token 并返回
@ -142,7 +143,7 @@ async def get_current_user_info(
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
):
"""
获取当前登录用户信息:
获取当前登录用户信息:
- 需在请求头携带 Token格式: Bearer <token>
- 认证通过后返回用户信息
"""
@ -152,3 +153,98 @@ async def get_current_user_info(
data=current_user
)
# ------------------------------
# 4. 获取用户列表(仅需登录权限)
# ------------------------------
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
async def get_user_list(
page: int = Query(1, ge=1, description="页码从1开始"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
username: Optional[str] = Query(None, description="用户名模糊搜索"),
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
):
"""
获取用户列表:
- 需登录权限(请求头携带 Token: Bearer <token>
- 支持分页查询page=页码page_size=每页条数)
- 支持用户名模糊搜索(如输入"test"可匹配"test123""admin_test"等)
- 仅返回用户ID、用户名、创建时间、更新时间不包含密码等敏感信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 计算分页偏移量page从1开始偏移量=(页码-1*每页条数)
offset = (page - 1) * page_size
# 基础查询(仅查非敏感字段)
base_query = """
SELECT id, username, created_at, updated_at
FROM users
"""
# 总条数查询(用于分页计算)
count_query = "SELECT COUNT(*) as total FROM users"
# 条件拼接(支持用户名模糊搜索)
conditions = []
params = []
if username:
conditions.append("username LIKE %s")
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
# 构建最终查询语句
if conditions:
where_clause = " WHERE " + " AND ".join(conditions)
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
final_count_query = f"{count_query}{where_clause}"
params.extend([page_size, offset]) # 追加分页参数
else:
final_query = f"{base_query} LIMIT %s OFFSET %s"
final_count_query = count_query
params = [page_size, offset]
# 1. 查询用户列表数据
cursor.execute(final_query, params)
users = cursor.fetchall()
# 2. 查询总条数(用于计算总页数)
count_params = [f"%{username}%"] if username else []
cursor.execute(final_count_query, count_params)
total = cursor.fetchone()["total"]
# 3. 转换为UserResponse模型确保字段匹配
user_list = [
UserResponse(
id=user["id"],
username=user["username"],
created_at=user["created_at"],
updated_at=user["updated_at"]
)
for user in users
]
# 4. 计算总页数向上取整如11条数据每页10条=2页
total_pages = (total + page_size - 1) // page_size
# 返回结果(包含列表和分页信息)
return APIResponse(
code=200,
message="获取用户列表成功",
data={
"users": user_list,
"pagination": {
"page": page, # 当前页码
"page_size": page_size, # 每页条数
"total": total, # 总数据量
"total_pages": total_pages # 总页数
}
}
)
except MySQLError as e:
raise Exception(f"获取用户列表失败: {str(e)}") from e
finally:
# 无论成功失败,都关闭数据库连接
db.close_connection(conn, cursor)