目前可以成功动态更换模型运行的
This commit is contained in:
@ -1,162 +1,140 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from mysql.connector import Error as MySQLError
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from ds.db import db
|
||||
from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse
|
||||
from schema.response_schema import APIResponse
|
||||
from middle.auth_middleware import get_current_user
|
||||
from schema.user_schema import UserResponse
|
||||
|
||||
from util.face_util import add_binary_data,get_average_feature
|
||||
#初始化实例
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/faces",
|
||||
tags=["人脸管理"]
|
||||
from schema.face_schema import (
|
||||
FaceCreateRequest,
|
||||
FaceUpdateRequest,
|
||||
FaceResponse,
|
||||
FaceListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
from util.face_util import add_binary_data, get_average_feature
|
||||
from util.file_util import save_face_to_up_images
|
||||
|
||||
router = APIRouter(prefix="/faces", tags=["人脸管理"])
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 1. 创建人脸记录(核心修正: ID 数据库自增、前端无需传)
|
||||
# 1. 创建人脸记录(使用修复后的路径)
|
||||
# ------------------------------
|
||||
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件、ID自增)")
|
||||
@router.post("", response_model=APIResponse, summary="创建人脸记录")
|
||||
async def create_face(
|
||||
# 前端仅需传: name(可选、Form格式)、file(必传、文件)
|
||||
request: Request,
|
||||
name: str = Form(None, max_length=255, description="名称(可选)"),
|
||||
file: UploadFile = File(..., description="人脸文件(必传、暂不处理内容)")
|
||||
file: UploadFile = File(..., description="人脸文件(必传)")
|
||||
):
|
||||
"""
|
||||
创建人脸记录:
|
||||
- 需登录认证
|
||||
- 前端传参: multipart/form-data 表单(name 可选、file 必传)
|
||||
- ID 由数据库自动生成、无需前端传入
|
||||
- 暂不处理文件内容、eigenvalue 设为 None
|
||||
"""
|
||||
|
||||
# 调用你的方法
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 1. 用模型校验 name(仅校验长度、无需ID)
|
||||
face_create = FaceCreateRequest(name=name)
|
||||
client_ip = request.client.host if request.client else ""
|
||||
if not client_ip:
|
||||
raise HTTPException(status_code=400, detail="无法获取客户端IP")
|
||||
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 把文件转为二进制数组
|
||||
# 读取图片并保存(使用修复后的路径逻辑)
|
||||
file_content = await file.read()
|
||||
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "jpg"
|
||||
save_result = save_face_to_up_images(
|
||||
client_ip=client_ip,
|
||||
face_name=name,
|
||||
image_bytes=file_content,
|
||||
image_format=file_ext
|
||||
)
|
||||
if not save_result["success"]:
|
||||
raise HTTPException(status_code=500, detail=f"图片保存失败:{save_result['msg']}")
|
||||
db_image_path = save_result["db_path"] # 从修复后的方法获取路径
|
||||
|
||||
# 计算特征值
|
||||
flag, eigenvalue = add_binary_data(file_content)
|
||||
# 提取人脸特征
|
||||
detect_success, detect_result = add_binary_data(file_content)
|
||||
if not detect_success:
|
||||
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
|
||||
eigenvalue = detect_result
|
||||
|
||||
if flag == False:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="未检测到人脸"
|
||||
)
|
||||
|
||||
# 打印数组长度
|
||||
print(f"文件大小: {len(file_content)} 字节")
|
||||
|
||||
# 2. 插入数据库: 无需传 ID(自增)、只传 name 和 eigenvalue(None)
|
||||
# 插入数据库
|
||||
insert_query = """
|
||||
INSERT INTO face (name, eigenvalue)
|
||||
VALUES (%s, %s)
|
||||
INSERT INTO face (name, eigenvalue, address)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
|
||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue), db_image_path))
|
||||
conn.commit()
|
||||
|
||||
# 3. 获取数据库自动生成的 ID(关键: 用 LAST_INSERT_ID() 查刚插入的记录)
|
||||
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
|
||||
cursor.execute(select_new_query)
|
||||
# 查询新记录
|
||||
cursor.execute("""
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = LAST_INSERT_ID()
|
||||
""")
|
||||
created_face = cursor.fetchone()
|
||||
|
||||
if not created_face:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="创建人脸记录成功、但无法获取新创建的记录"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
|
||||
|
||||
return APIResponse(
|
||||
code=201,
|
||||
message=f"人脸记录创建成功(ID: {created_face['id']}、文件名: {file.filename})",
|
||||
data=FaceResponse(** created_face)
|
||||
message=f"人脸记录创建成功(ID: {created_face['id']})",
|
||||
data=FaceResponse(**created_face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
# 改为使用HTTPException
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"创建人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
# 捕获其他可能的异常
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"服务器错误: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
finally:
|
||||
await file.close() # 关闭文件流
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
|
||||
flag, eigenvalue = add_binary_data(file_content)
|
||||
if flag == False:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="未检测到人脸"
|
||||
)
|
||||
|
||||
# 将 eigenvalue 转为 str
|
||||
eigenvalue = str(eigenvalue)
|
||||
|
||||
# 其他接口(获取单条/列表、更新、删除、获取图片)与之前一致,无需修改
|
||||
# ------------------------------
|
||||
# 2. 获取单个人脸记录(不变、用自增ID查询)
|
||||
# 2. 获取单个人脸记录
|
||||
# ------------------------------
|
||||
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
||||
async def get_face(
|
||||
face_id: int, # 这里的 ID 是数据库自增的、前端从创建响应中获取
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
async def get_face(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = "SELECT * FROM face WHERE id = %s"
|
||||
query = """
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = %s
|
||||
"""
|
||||
cursor.execute(query, (face_id,))
|
||||
face = cursor.fetchone()
|
||||
|
||||
if not face:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="人脸记录查询成功",
|
||||
message="查询成功",
|
||||
data=FaceResponse(**face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
# 改为使用HTTPException
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理)
|
||||
# ------------------------------
|
||||
# 3. 获取所有人脸记录(不变)
|
||||
# 3. 获取人脸列表
|
||||
# ------------------------------
|
||||
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
|
||||
async def get_all_faces(
|
||||
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
|
||||
async def get_face_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
has_eigenvalue: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
@ -164,50 +142,66 @@ async def get_all_faces(
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = "SELECT * FROM face ORDER BY id" # 按自增ID排序
|
||||
cursor.execute(query)
|
||||
faces = cursor.fetchall()
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if has_eigenvalue is not None:
|
||||
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
|
||||
|
||||
# 总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM face"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 列表数据
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
"""
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
face_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="所有人脸记录查询成功",
|
||||
data=[FaceResponse(** face) for face in faces]
|
||||
message=f"获取成功(共{total}条)",
|
||||
data=FaceListResponse(
|
||||
total=total,
|
||||
faces=[FaceResponse(**face) for face in face_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询所有人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 4. 更新人脸记录(不变、用自增ID更新)
|
||||
# 4. 更新人脸记录
|
||||
# ------------------------------
|
||||
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
|
||||
async def update_face(
|
||||
face_id: int,
|
||||
face_update: FaceUpdateRequest,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
async def update_face(face_id: int, face_update: FaceUpdateRequest):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查记录是否存在
|
||||
check_query = "SELECT id FROM face WHERE id = %s"
|
||||
cursor.execute(check_query, (face_id,))
|
||||
existing_face = cursor.fetchone()
|
||||
if not existing_face:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||||
)
|
||||
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
|
||||
exist_face = cursor.fetchone()
|
||||
if not exist_face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
old_db_path = exist_face["address"]
|
||||
|
||||
# 构建更新语句
|
||||
update_fields = []
|
||||
params = []
|
||||
if face_update.name is not None:
|
||||
@ -216,6 +210,18 @@ async def update_face(
|
||||
if face_update.eigenvalue is not None:
|
||||
update_fields.append("eigenvalue = %s")
|
||||
params.append(face_update.eigenvalue)
|
||||
if face_update.address is not None:
|
||||
# 删除旧图片(相对路径转绝对路径)
|
||||
if old_db_path:
|
||||
old_abs_path = Path(old_db_path).resolve()
|
||||
if old_abs_path.exists():
|
||||
try:
|
||||
old_abs_path.unlink() # 使用Path方法删除更安全
|
||||
print(f"[FaceRouter] 已删除旧图片:{old_abs_path}")
|
||||
except Exception as e:
|
||||
print(f"[FaceRouter] 删除旧图片失败:{str(e)}")
|
||||
update_fields.append("address = %s")
|
||||
params.append(face_update.address)
|
||||
|
||||
if not update_fields:
|
||||
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
||||
@ -225,117 +231,143 @@ async def update_face(
|
||||
cursor.execute(update_query, params)
|
||||
conn.commit()
|
||||
|
||||
# 查询更新后记录
|
||||
select_query = "SELECT * FROM face WHERE id = %s"
|
||||
cursor.execute(select_query, (face_id,))
|
||||
cursor.execute("""
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = %s
|
||||
""", (face_id,))
|
||||
updated_face = cursor.fetchone()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="人脸记录更新成功",
|
||||
message="更新成功",
|
||||
data=FaceResponse(**updated_face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"更新失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 5. 删除人脸记录(不变、用自增ID删除)
|
||||
# 5. 删除人脸记录
|
||||
# ------------------------------
|
||||
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
||||
async def delete_face(
|
||||
face_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
async def delete_face(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
check_query = "SELECT id FROM face WHERE id = %s"
|
||||
cursor.execute(check_query, (face_id,))
|
||||
existing_face = cursor.fetchone()
|
||||
if not existing_face:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||||
)
|
||||
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
|
||||
exist_face = cursor.fetchone()
|
||||
if not exist_face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
old_db_path = exist_face["address"]
|
||||
|
||||
delete_query = "DELETE FROM face WHERE id = %s"
|
||||
cursor.execute(delete_query, (face_id,))
|
||||
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
|
||||
conn.commit()
|
||||
|
||||
# 删除图片
|
||||
extra_msg = ""
|
||||
if old_db_path:
|
||||
old_abs_path = Path(old_db_path).resolve()
|
||||
if old_abs_path.exists():
|
||||
try:
|
||||
old_abs_path.unlink()
|
||||
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
|
||||
extra_msg = "(已同步删除图片)"
|
||||
except Exception as e:
|
||||
print(f"[FaceRouter] 删除图片失败:{str(e)}")
|
||||
extra_msg = "(图片删除失败)"
|
||||
else:
|
||||
extra_msg = "(图片不存在)"
|
||||
else:
|
||||
extra_msg = "(无关联图片)"
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"ID为 {face_id} 的人脸记录删除成功",
|
||||
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"删除人脸记录失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
def get_all_face_name_with_eigenvalue() -> dict:
|
||||
"""
|
||||
获取所有人脸的名称及其对应的特征值、组成字典返回
|
||||
key: 人脸名称(name)
|
||||
value: 人脸特征值(eigenvalue)、若名称重复则返回平均特征值
|
||||
注: 过滤掉name为None的记录、避免字典key为None的情况
|
||||
"""
|
||||
# ------------------------------
|
||||
# 6. 获取人脸图片
|
||||
# ------------------------------
|
||||
@router.get("/{face_id}/image", summary="获取人脸图片")
|
||||
async def get_face_image(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = "SELECT address, name FROM face WHERE id = %s"
|
||||
cursor.execute(query, (face_id,))
|
||||
face = cursor.fetchone()
|
||||
|
||||
if not face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
|
||||
db_path = face["address"]
|
||||
abs_path = Path(db_path).resolve() # 转为绝对路径
|
||||
if not db_path or not abs_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"图片不存在(路径:{db_path})")
|
||||
|
||||
return FileResponse(
|
||||
path=abs_path,
|
||||
filename=f"face_{face_id}_{face['name'] or '未命名'}.{db_path.split('.')[-1]}",
|
||||
media_type=f"image/{db_path.split('.')[-1]}"
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"获取图片失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法
|
||||
# ------------------------------
|
||||
def get_all_face_name_with_eigenvalue() -> dict:
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 1. 建立数据库连接并获取游标(dictionary=True使结果以字典形式返回)
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 2. 执行SQL查询: 只获取name非空的记录、减少数据传输
|
||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
faces = cursor.fetchall() # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...]
|
||||
faces = cursor.fetchall()
|
||||
|
||||
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
|
||||
name_to_eigenvalues = {}
|
||||
for face in faces:
|
||||
name = face["name"]
|
||||
eigenvalue = face["eigenvalue"]
|
||||
# 若名称已存在、追加特征值;否则新建列表存储
|
||||
if name in name_to_eigenvalues:
|
||||
name_to_eigenvalues[name].append(eigenvalue)
|
||||
else:
|
||||
name_to_eigenvalues[name] = [eigenvalue]
|
||||
|
||||
# 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值
|
||||
face_dict = {}
|
||||
for name, eigenvalues in name_to_eigenvalues.items():
|
||||
|
||||
# 处理特征值: 多个则求平均、单个则直接使用
|
||||
if len(eigenvalues) > 1:
|
||||
# 调用外部方法计算平均特征值(需确保binary_face_feature_handler已正确导入)
|
||||
face_dict[name] = get_average_feature(eigenvalues)
|
||||
else:
|
||||
# 取列表中唯一的特征值(避免value为列表类型)
|
||||
face_dict[name] = eigenvalues[0]
|
||||
|
||||
return face_dict
|
||||
|
||||
except MySQLError as e:
|
||||
# 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题)
|
||||
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
|
||||
raise Exception(f"获取人脸特征失败: {str(e)}") from e
|
||||
finally:
|
||||
# 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏)
|
||||
db.close_connection(conn, cursor)
|
497
service/model_service.py
Normal file
497
service/model_service.py
Normal file
@ -0,0 +1,497 @@
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from mysql.connector import Error as MySQLError
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# 复用项目依赖
|
||||
from ds.db import db
|
||||
from schema.model_schema import (
|
||||
ModelCreateRequest,
|
||||
ModelUpdateRequest,
|
||||
ModelResponse,
|
||||
ModelListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
|
||||
|
||||
# 路径配置
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent
|
||||
MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models"
|
||||
MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True)
|
||||
DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
|
||||
|
||||
# 模型限制
|
||||
ALLOWED_MODEL_EXT = {"pt"}
|
||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
# 全局模型变量
|
||||
global _yolo_model
|
||||
_yolo_model = None
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 工具函数:验证模型路径
|
||||
def get_valid_model_abs_path(relative_path: str) -> str:
|
||||
try:
|
||||
relative_path = relative_path.replace("/", os.sep)
|
||||
model_abs_path = PROJECT_ROOT / relative_path
|
||||
model_abs_path = model_abs_path.resolve()
|
||||
model_abs_path_str = str(model_abs_path)
|
||||
|
||||
if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
if not model_abs_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"模型文件不存在!路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
if not model_abs_path.is_file():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"路径不是文件!路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
file_size = model_abs_path.stat().st_size
|
||||
if file_size > MAX_MODEL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型文件过大({file_size // 1024 // 1024}MB),超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
|
||||
)
|
||||
|
||||
file_ext = model_abs_path.suffix.lower()
|
||||
if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
|
||||
)
|
||||
|
||||
print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
|
||||
return model_abs_path_str
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"路径处理失败:{str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
# 1. 上传模型
|
||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||
async def upload_model(
|
||||
name: str = Form(..., description="模型名称"),
|
||||
description: str = Form(None, description="模型描述"),
|
||||
is_default: bool = Form(False, description="是否设为默认模型"),
|
||||
file: UploadFile = File(..., description=f"YOLO模型文件(.pt,最大{MAX_MODEL_SIZE // 1024 // 1024}MB)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
saved_file_path = None
|
||||
try:
|
||||
# 校验文件
|
||||
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||
if file_ext not in ALLOWED_MODEL_EXT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
|
||||
)
|
||||
if file.size > MAX_MODEL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB,当前{file.size // 1024 // 1024}MB"
|
||||
)
|
||||
|
||||
# 保存文件
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}"
|
||||
saved_file_path = MODEL_SAVE_ROOT / safe_filename
|
||||
with open(saved_file_path, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
saved_file_path.chmod(0o644) # 设置权限
|
||||
|
||||
# 数据库路径处理
|
||||
db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/")
|
||||
|
||||
# 数据库操作
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
if is_default:
|
||||
cursor.execute("UPDATE model SET is_default = 0")
|
||||
|
||||
insert_sql = """
|
||||
INSERT INTO model (name, path, is_default, description, file_size)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size))
|
||||
conn.commit()
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
|
||||
new_model = cursor.fetchone()
|
||||
if not new_model:
|
||||
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||||
|
||||
# 加载默认模型
|
||||
global _yolo_model
|
||||
if is_default:
|
||||
valid_abs_path = get_valid_model_abs_path(db_relative_path)
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=201,
|
||||
message=f"模型上传成功!ID:{new_model['id']}",
|
||||
data=ModelResponse(**new_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
if saved_file_path and saved_file_path.exists():
|
||||
saved_file_path.unlink()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
except Exception as e:
|
||||
if saved_file_path and saved_file_path.exists():
|
||||
saved_file_path.unlink()
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 2. 获取模型列表
|
||||
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||||
async def get_model_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
is_default: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if is_default is not None:
|
||||
where_clause.append("is_default = %s")
|
||||
params.append(1 if is_default else 0)
|
||||
|
||||
# 总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM model"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页数据
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM model"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
model_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取成功!共{total}条记录",
|
||||
data=ModelListResponse(
|
||||
total=total,
|
||||
models=[ModelResponse(**model) for model in model_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 3. 获取默认模型
|
||||
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
|
||||
async def get_default_model():
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE is_default = 1")
|
||||
default_model = cursor.fetchone()
|
||||
|
||||
if not default_model:
|
||||
raise HTTPException(status_code=404, detail="暂无默认模型")
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
global _yolo_model
|
||||
|
||||
if not _yolo_model:
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="默认模型查询成功",
|
||||
data=ModelResponse(**default_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 4. 获取单个模型详情
|
||||
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
|
||||
async def get_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
model = cursor.fetchone()
|
||||
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
try:
|
||||
model_abs_path = get_valid_model_abs_path(model["path"])
|
||||
except HTTPException as e:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"查询成功,但路径异常:{e.detail}",
|
||||
data=ModelResponse(**model)
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=ModelResponse(**model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 5. 更新模型信息
|
||||
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
|
||||
async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
exist_model = cursor.fetchone()
|
||||
if not exist_model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
update_fields = []
|
||||
params = []
|
||||
if model_update.name is not None:
|
||||
update_fields.append("name = %s")
|
||||
params.append(model_update.name)
|
||||
if model_update.description is not None:
|
||||
update_fields.append("description = %s")
|
||||
params.append(model_update.description)
|
||||
|
||||
need_load_default = False
|
||||
if model_update.is_default is not None:
|
||||
if model_update.is_default:
|
||||
cursor.execute("UPDATE model SET is_default = 0")
|
||||
update_fields.append("is_default = 1")
|
||||
need_load_default = True
|
||||
else:
|
||||
cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
|
||||
default_count = cursor.fetchone()["cnt"]
|
||||
if default_count == 1 and exist_model["is_default"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="当前是唯一默认模型,不可取消!"
|
||||
)
|
||||
update_fields.append("is_default = 0")
|
||||
|
||||
if not update_fields:
|
||||
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
||||
|
||||
params.append(model_id)
|
||||
update_sql = f"""
|
||||
UPDATE model
|
||||
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
"""
|
||||
cursor.execute(update_sql, params)
|
||||
conn.commit()
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
updated_model = cursor.fetchone()
|
||||
|
||||
global _yolo_model
|
||||
if need_load_default:
|
||||
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="模型更新成功",
|
||||
data=ModelResponse(**updated_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 6. 删除模型
|
||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||
async def delete_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
exist_model = cursor.fetchone()
|
||||
if not exist_model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
if exist_model["is_default"]:
|
||||
raise HTTPException(status_code=400, detail="默认模型不可删除!")
|
||||
|
||||
try:
|
||||
model_abs_path_str = get_valid_model_abs_path(exist_model["path"])
|
||||
model_abs_path = Path(model_abs_path_str)
|
||||
except HTTPException as e:
|
||||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||
conn.commit()
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"记录删除成功,文件异常:{e.detail}",
|
||||
data=None
|
||||
)
|
||||
|
||||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||
conn.commit()
|
||||
|
||||
extra_msg = ""
|
||||
try:
|
||||
model_abs_path.unlink()
|
||||
extra_msg = f"(已删除文件)"
|
||||
except Exception as e:
|
||||
extra_msg = f"(文件删除失败:{str(e)})"
|
||||
|
||||
global _yolo_model
|
||||
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str:
|
||||
_yolo_model = None
|
||||
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型删除成功!ID:{model_id} {extra_msg}",
|
||||
data=None
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 7. 下载模型文件
|
||||
@router.get("/{model_id}/download", summary="下载模型文件")
|
||||
async def download_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
model = cursor.fetchone()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(model["path"])
|
||||
model_abs_path = Path(valid_abs_path)
|
||||
|
||||
return FileResponse(
|
||||
path=model_abs_path,
|
||||
filename=f"model_{model_id}_{model['name']}.pt",
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 对外提供当前模型
|
||||
def get_current_yolo_model():
|
||||
"""供检测模块获取当前加载的模型"""
|
||||
global _yolo_model
|
||||
if not _yolo_model:
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("SELECT path FROM model WHERE is_default = 1")
|
||||
default_model = cursor.fetchone()
|
||||
if not default_model:
|
||||
print("[get_current_yolo_model] 暂无默认模型")
|
||||
return None
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
print(f"[get_current_yolo_model] 自动加载默认模型成功")
|
||||
else:
|
||||
print(f"[get_current_yolo_model] 自动加载默认模型失败")
|
||||
except Exception as e:
|
||||
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
return _yolo_model
|
@ -1,6 +1,7 @@
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
@ -11,7 +12,7 @@ from middle.auth_middleware import (
|
||||
verify_password,
|
||||
create_access_token,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
get_current_user
|
||||
get_current_user # 仅保留登录用户校验,移除is_admin导入
|
||||
)
|
||||
|
||||
# 创建用户接口路由(前缀 /users、标签用于 Swagger 分类)
|
||||
@ -27,7 +28,7 @@ router = APIRouter(
|
||||
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
||||
async def user_register(request: UserRegisterRequest):
|
||||
"""
|
||||
用户注册:
|
||||
用户注册:
|
||||
- 校验用户名是否已存在
|
||||
- 加密密码后插入数据库
|
||||
- 返回注册成功信息
|
||||
@ -78,7 +79,7 @@ async def user_register(request: UserRegisterRequest):
|
||||
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
||||
async def user_login(request: UserLoginRequest):
|
||||
"""
|
||||
用户登录:
|
||||
用户登录:
|
||||
- 校验用户名是否存在
|
||||
- 校验密码是否正确
|
||||
- 生成 JWT Token 并返回
|
||||
@ -142,7 +143,7 @@ async def get_current_user_info(
|
||||
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
||||
):
|
||||
"""
|
||||
获取当前登录用户信息:
|
||||
获取当前登录用户信息:
|
||||
- 需在请求头携带 Token(格式: Bearer <token>)
|
||||
- 认证通过后返回用户信息
|
||||
"""
|
||||
@ -152,3 +153,98 @@ async def get_current_user_info(
|
||||
data=current_user
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 4. 获取用户列表(仅需登录权限)
|
||||
# ------------------------------
|
||||
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
|
||||
async def get_user_list(
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||
username: Optional[str] = Query(None, description="用户名模糊搜索"),
|
||||
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
|
||||
):
|
||||
"""
|
||||
获取用户列表:
|
||||
- 需登录权限(请求头携带 Token: Bearer <token>)
|
||||
- 支持分页查询(page=页码,page_size=每页条数)
|
||||
- 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等)
|
||||
- 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息)
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 计算分页偏移量(page从1开始,偏移量=(页码-1)*每页条数)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 基础查询(仅查非敏感字段)
|
||||
base_query = """
|
||||
SELECT id, username, created_at, updated_at
|
||||
FROM users
|
||||
"""
|
||||
# 总条数查询(用于分页计算)
|
||||
count_query = "SELECT COUNT(*) as total FROM users"
|
||||
|
||||
# 条件拼接(支持用户名模糊搜索)
|
||||
conditions = []
|
||||
params = []
|
||||
if username:
|
||||
conditions.append("username LIKE %s")
|
||||
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
|
||||
|
||||
# 构建最终查询语句
|
||||
if conditions:
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
|
||||
final_count_query = f"{count_query}{where_clause}"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
else:
|
||||
final_query = f"{base_query} LIMIT %s OFFSET %s"
|
||||
final_count_query = count_query
|
||||
params = [page_size, offset]
|
||||
|
||||
# 1. 查询用户列表数据
|
||||
cursor.execute(final_query, params)
|
||||
users = cursor.fetchall()
|
||||
|
||||
# 2. 查询总条数(用于计算总页数)
|
||||
count_params = [f"%{username}%"] if username else []
|
||||
cursor.execute(final_count_query, count_params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 转换为UserResponse模型(确保字段匹配)
|
||||
user_list = [
|
||||
UserResponse(
|
||||
id=user["id"],
|
||||
username=user["username"],
|
||||
created_at=user["created_at"],
|
||||
updated_at=user["updated_at"]
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
# 4. 计算总页数(向上取整,如11条数据每页10条=2页)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# 返回结果(包含列表和分页信息)
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取用户列表成功",
|
||||
data={
|
||||
"users": user_list,
|
||||
"pagination": {
|
||||
"page": page, # 当前页码
|
||||
"page_size": page_size, # 每页条数
|
||||
"total": total, # 总数据量
|
||||
"total_pages": total_pages # 总页数
|
||||
}
|
||||
}
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取用户列表失败: {str(e)}") from e
|
||||
finally:
|
||||
# 无论成功失败,都关闭数据库连接
|
||||
db.close_connection(conn, cursor)
|
Reference in New Issue
Block a user