目前可以成功动态更换模型运行的

This commit is contained in:
2025-09-12 14:05:09 +08:00
parent 435b2a0e6c
commit 4be7f7bf14
13 changed files with 1518 additions and 325 deletions

View File

@ -1,162 +1,140 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
import os
from pathlib import Path
from ds.db import db
from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse
from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse
from util.face_util import add_binary_data,get_average_feature
#初始化实例
router = APIRouter(
prefix="/faces",
tags=["人脸管理"]
from schema.face_schema import (
FaceCreateRequest,
FaceUpdateRequest,
FaceResponse,
FaceListResponse
)
from schema.response_schema import APIResponse
from util.face_util import add_binary_data, get_average_feature
from util.file_util import save_face_to_up_images
router = APIRouter(prefix="/faces", tags=["人脸管理"])
# ------------------------------
# 1. 创建人脸记录(核心修正: ID 数据库自增、前端无需传
# 1. 创建人脸记录(使用修复后的路径
# ------------------------------
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件、ID自增")
@router.post("", response_model=APIResponse, summary="创建人脸记录")
async def create_face(
# 前端仅需传: name可选、Form格式、file必传、文件
request: Request,
name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传、暂不处理内容")
file: UploadFile = File(..., description="人脸文件(必传)")
):
"""
创建人脸记录:
- 需登录认证
- 前端传参: multipart/form-data 表单name 可选、file 必传)
- ID 由数据库自动生成、无需前端传入
- 暂不处理文件内容、eigenvalue 设为 None
"""
# 调用你的方法
conn = None
cursor = None
try:
# 1. 用模型校验 name仅校验长度、无需ID
face_create = FaceCreateRequest(name=name)
client_ip = request.client.host if request.client else ""
if not client_ip:
raise HTTPException(status_code=400, detail="无法获取客户端IP")
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 把文件转为二进制数组
# 读取图片并保存(使用修复后的路径逻辑)
file_content = await file.read()
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "jpg"
save_result = save_face_to_up_images(
client_ip=client_ip,
face_name=name,
image_bytes=file_content,
image_format=file_ext
)
if not save_result["success"]:
raise HTTPException(status_code=500, detail=f"图片保存失败:{save_result['msg']}")
db_image_path = save_result["db_path"] # 从修复后的方法获取路径
# 计算特征
flag, eigenvalue = add_binary_data(file_content)
# 提取人脸特征
detect_success, detect_result = add_binary_data(file_content)
if not detect_success:
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
eigenvalue = detect_result
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 打印数组长度
print(f"文件大小: {len(file_content)} 字节")
# 2. 插入数据库: 无需传 ID自增、只传 name 和 eigenvalueNone
# 插入数据库
insert_query = """
INSERT INTO face (name, eigenvalue)
VALUES (%s, %s)
INSERT INTO face (name, eigenvalue, address)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (face_create.name, str(eigenvalue)))
cursor.execute(insert_query, (face_create.name, str(eigenvalue), db_image_path))
conn.commit()
# 3. 获取数据库自动生成的 ID关键: 用 LAST_INSERT_ID() 查刚插入的记录
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
cursor.execute(select_new_query)
# 查询新记录
cursor.execute("""
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = LAST_INSERT_ID()
""")
created_face = cursor.fetchone()
if not created_face:
raise HTTPException(
status_code=500,
detail="创建人脸记录成功、但无法获取新创建的记录"
)
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
return APIResponse(
code=201,
message=f"人脸记录创建成功ID: {created_face['id']}、文件名: {file.filename}",
data=FaceResponse(** created_face)
message=f"人脸记录创建成功ID: {created_face['id']}",
data=FaceResponse(**created_face)
)
except MySQLError as e:
if conn:
conn.rollback()
# 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"创建人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
except Exception as e:
# 捕获其他可能的异常
raise HTTPException(
status_code=500,
detail=f"服务器错误: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
finally:
await file.close() # 关闭文件流
await file.close()
db.close_connection(conn, cursor)
# 调用人脸识别得到特征值(这里可以添加你的人脸识别逻辑)
flag, eigenvalue = add_binary_data(file_content)
if flag == False:
raise HTTPException(
status_code=500,
detail="未检测到人脸"
)
# 将 eigenvalue 转为 str
eigenvalue = str(eigenvalue)
# 其他接口(获取单条/列表、更新、删除、获取图片)与之前一致,无需修改
# ------------------------------
# 2. 获取单个人脸记录不变、用自增ID查询
# 2. 获取单个人脸记录
# ------------------------------
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
async def get_face(
face_id: int, # 这里的 ID 是数据库自增的、前端从创建响应中获取
current_user: UserResponse = Depends(get_current_user)
):
async def get_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM face WHERE id = %s"
query = """
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
"""
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
return APIResponse(
code=200,
message="人脸记录查询成功",
message="查询成功",
data=FaceResponse(**face)
)
except MySQLError as e:
# 改为使用HTTPException
raise HTTPException(
status_code=500,
detail=f"查询人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 后续 3.获取所有、4.更新、5.删除 接口(修复异常处理)
# ------------------------------
# 3. 获取所有人脸记录(不变)
# 3. 获取人脸列表
# ------------------------------
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
async def get_all_faces(
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
async def get_face_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
has_eigenvalue: bool = Query(None)
):
conn = None
cursor = None
@ -164,50 +142,66 @@ async def get_all_faces(
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM face ORDER BY id" # 按自增ID排序
cursor.execute(query)
faces = cursor.fetchall()
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if has_eigenvalue is not None:
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
# 总记录数
count_query = "SELECT COUNT(*) AS total FROM face"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 列表数据
offset = (page - 1) * page_size
list_query = """
SELECT id, name, address, created_at, updated_at
FROM face
"""
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_query, params)
face_list = cursor.fetchall()
return APIResponse(
code=200,
message="所有人脸记录查询成功",
data=[FaceResponse(** face) for face in faces]
message=f"获取成功(共{total}条)",
data=FaceListResponse(
total=total,
faces=[FaceResponse(**face) for face in face_list]
)
)
except MySQLError as e:
raise HTTPException(
status_code=500,
detail=f"查询所有人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 4. 更新人脸记录不变、用自增ID更新
# 4. 更新人脸记录
# ------------------------------
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
async def update_face(
face_id: int,
face_update: FaceUpdateRequest,
current_user: UserResponse = Depends(get_current_user)
):
async def update_face(face_id: int, face_update: FaceUpdateRequest):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查记录是否存在
check_query = "SELECT id FROM face WHERE id = %s"
cursor.execute(check_query, (face_id,))
existing_face = cursor.fetchone()
if not existing_face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
exist_face = cursor.fetchone()
if not exist_face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
old_db_path = exist_face["address"]
# 构建更新语句
update_fields = []
params = []
if face_update.name is not None:
@ -216,6 +210,18 @@ async def update_face(
if face_update.eigenvalue is not None:
update_fields.append("eigenvalue = %s")
params.append(face_update.eigenvalue)
if face_update.address is not None:
# 删除旧图片(相对路径转绝对路径)
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink() # 使用Path方法删除更安全
print(f"[FaceRouter] 已删除旧图片:{old_abs_path}")
except Exception as e:
print(f"[FaceRouter] 删除旧图片失败:{str(e)}")
update_fields.append("address = %s")
params.append(face_update.address)
if not update_fields:
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
@ -225,117 +231,143 @@ async def update_face(
cursor.execute(update_query, params)
conn.commit()
# 查询更新后记录
select_query = "SELECT * FROM face WHERE id = %s"
cursor.execute(select_query, (face_id,))
cursor.execute("""
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
""", (face_id,))
updated_face = cursor.fetchone()
return APIResponse(
code=200,
message="人脸记录更新成功",
message="更新成功",
data=FaceResponse(**updated_face)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"更新人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"更新失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 5. 删除人脸记录不变、用自增ID删除
# 5. 删除人脸记录
# ------------------------------
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
async def delete_face(
face_id: int,
current_user: UserResponse = Depends(get_current_user)
):
async def delete_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
check_query = "SELECT id FROM face WHERE id = %s"
cursor.execute(check_query, (face_id,))
existing_face = cursor.fetchone()
if not existing_face:
raise HTTPException(
status_code=404,
detail=f"ID为 {face_id} 的人脸记录不存在"
)
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
exist_face = cursor.fetchone()
if not exist_face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
old_db_path = exist_face["address"]
delete_query = "DELETE FROM face WHERE id = %s"
cursor.execute(delete_query, (face_id,))
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
conn.commit()
# 删除图片
extra_msg = ""
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink()
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
extra_msg = "(已同步删除图片)"
except Exception as e:
print(f"[FaceRouter] 删除图片失败:{str(e)}")
extra_msg = "(图片删除失败)"
else:
extra_msg = "(图片不存在)"
else:
extra_msg = "(无关联图片)"
return APIResponse(
code=200,
message=f"ID为 {face_id}人脸记录删除成功",
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"删除人脸记录失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
def get_all_face_name_with_eigenvalue() -> dict:
"""
获取所有人脸的名称及其对应的特征值、组成字典返回
key: 人脸名称name
value: 人脸特征值eigenvalue、若名称重复则返回平均特征值
注: 过滤掉name为None的记录、避免字典key为None的情况
"""
# ------------------------------
# 6. 获取人脸图片
# ------------------------------
@router.get("/{face_id}/image", summary="获取人脸图片")
async def get_face_image(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = "SELECT address, name FROM face WHERE id = %s"
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
db_path = face["address"]
abs_path = Path(db_path).resolve() # 转为绝对路径
if not db_path or not abs_path.exists():
raise HTTPException(status_code=404, detail=f"图片不存在(路径:{db_path}")
return FileResponse(
path=abs_path,
filename=f"face_{face_id}_{face['name'] or '未命名'}.{db_path.split('.')[-1]}",
media_type=f"image/{db_path.split('.')[-1]}"
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"获取图片失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 内部工具方法
# ------------------------------
def get_all_face_name_with_eigenvalue() -> dict:
conn = None
cursor = None
try:
# 1. 建立数据库连接并获取游标dictionary=True使结果以字典形式返回
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 2. 执行SQL查询: 只获取name非空的记录、减少数据传输
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
cursor.execute(query)
faces = cursor.fetchall() # 返回结果: 列表套字典、如 [{"name":"张三","eigenvalue":...}, ...]
faces = cursor.fetchall()
# 3. 收集同一名称对应的所有特征值(处理名称重复场景)
name_to_eigenvalues = {}
for face in faces:
name = face["name"]
eigenvalue = face["eigenvalue"]
# 若名称已存在、追加特征值;否则新建列表存储
if name in name_to_eigenvalues:
name_to_eigenvalues[name].append(eigenvalue)
else:
name_to_eigenvalues[name] = [eigenvalue]
# 4. 构建最终字典: 重复名称取平均、唯一名称直接取特征值
face_dict = {}
for name, eigenvalues in name_to_eigenvalues.items():
# 处理特征值: 多个则求平均、单个则直接使用
if len(eigenvalues) > 1:
# 调用外部方法计算平均特征值需确保binary_face_feature_handler已正确导入
face_dict[name] = get_average_feature(eigenvalues)
else:
# 取列表中唯一的特征值避免value为列表类型
face_dict[name] = eigenvalues[0]
return face_dict
except MySQLError as e:
# 捕获数据库异常、添加上下文信息后重新抛出(便于定位问题)
raise Exception(f"获取人脸名称与特征值失败: {str(e)}") from e
raise Exception(f"获取人脸特征失败: {str(e)}") from e
finally:
# 5. 无论是否异常、均释放数据库连接和游标(避免资源泄漏)
db.close_connection(conn, cursor)