287 lines
9.8 KiB
Python
287 lines
9.8 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
||
from mysql.connector import Error as MySQLError
|
||
|
||
from ds.db import db
|
||
from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse
|
||
from schema.response_schema import APIResponse
|
||
from middle.auth_middleware import get_current_user
|
||
from schema.user_schema import UserResponse
|
||
from ocr.feature_extraction import BinaryFaceFeatureHandler
|
||
|
||
router = APIRouter(
|
||
prefix="/faces",
|
||
tags=["人脸管理"]
|
||
)
|
||
|
||
# 创建 BinaryFaceFeatureHandler 的实例
|
||
binary_face_feature_handler = BinaryFaceFeatureHandler()
|
||
|
||
|
||
# ------------------------------
|
||
# 1. 创建人脸记录(核心修正:ID 数据库自增,前端无需传)
|
||
# ------------------------------
|
||
@router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件,ID自增)")
|
||
async def create_face(
|
||
# 前端仅需传:name(可选,Form格式)、file(必传,文件)
|
||
name: str = Form(None, max_length=255, description="名称(可选)"),
|
||
file: UploadFile = File(..., description="人脸文件(必传,暂不处理内容)")
|
||
):
|
||
"""
|
||
创建人脸记录:
|
||
- 需登录认证
|
||
- 前端传参:multipart/form-data 表单(name 可选,file 必传)
|
||
- ID 由数据库自动生成,无需前端传入
|
||
- 暂不处理文件内容,eigenvalue 设为 None
|
||
"""
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
# 1. 用模型校验 name(仅校验长度,无需ID)
|
||
face_create = FaceCreateRequest(name=name)
|
||
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
# 把文件转为二进制数组
|
||
file_content = await file.read()
|
||
|
||
# 调用人脸识别得到特征值
|
||
|
||
# 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None)
|
||
insert_query = """
|
||
INSERT INTO face (name, eigenvalue)
|
||
VALUES (%s, %s)
|
||
"""
|
||
cursor.execute(insert_query, (face_create.name, None))
|
||
conn.commit()
|
||
|
||
# 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录)
|
||
select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()"
|
||
cursor.execute(select_new_query)
|
||
created_face = cursor.fetchone()
|
||
|
||
return APIResponse(
|
||
code=201,
|
||
message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})",
|
||
data=FaceResponse(**created_face)
|
||
)
|
||
except MySQLError as e:
|
||
if conn:
|
||
conn.rollback()
|
||
raise Exception(f"创建人脸记录失败:{str(e)}") from e
|
||
finally:
|
||
await file.close() # 关闭文件流
|
||
db.close_connection(conn, cursor)
|
||
|
||
|
||
# ------------------------------
|
||
# 2. 获取单个人脸记录(不变,用自增ID查询)
|
||
# ------------------------------
|
||
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
||
async def get_face(
|
||
face_id: int, # 这里的 ID 是数据库自增的,前端从创建响应中获取
|
||
current_user: UserResponse = Depends(get_current_user)
|
||
):
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
query = "SELECT * FROM face WHERE id = %s"
|
||
cursor.execute(query, (face_id,))
|
||
face = cursor.fetchone()
|
||
|
||
if not face:
|
||
raise HTTPException(
|
||
status_code=404,
|
||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||
)
|
||
|
||
return APIResponse(
|
||
code=200,
|
||
message="人脸记录查询成功",
|
||
data=FaceResponse(**face)
|
||
)
|
||
except MySQLError as e:
|
||
raise Exception(f"查询人脸记录失败:{str(e)}") from e
|
||
finally:
|
||
db.close_connection(conn, cursor)
|
||
|
||
|
||
# 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改)
|
||
# ------------------------------
|
||
# 3. 获取所有人脸记录(不变)
|
||
# ------------------------------
|
||
@router.get("", response_model=APIResponse, summary="获取所有人脸记录")
|
||
async def get_all_faces(
|
||
current_user: UserResponse = Depends(get_current_user)
|
||
):
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
query = "SELECT * FROM face ORDER BY id" # 按自增ID排序
|
||
cursor.execute(query)
|
||
faces = cursor.fetchall()
|
||
|
||
return APIResponse(
|
||
code=200,
|
||
message="所有人脸记录查询成功",
|
||
data=[FaceResponse(**face) for face in faces]
|
||
)
|
||
except MySQLError as e:
|
||
raise Exception(f"查询所有人脸记录失败:{str(e)}") from e
|
||
finally:
|
||
db.close_connection(conn, cursor)
|
||
|
||
|
||
# ------------------------------
|
||
# 4. 更新人脸记录(不变,用自增ID更新)
|
||
# ------------------------------
|
||
@router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录")
|
||
async def update_face(
|
||
face_id: int,
|
||
face_update: FaceUpdateRequest,
|
||
current_user: UserResponse = Depends(get_current_user)
|
||
):
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
# 检查记录是否存在
|
||
check_query = "SELECT id FROM face WHERE id = %s"
|
||
cursor.execute(check_query, (face_id,))
|
||
existing_face = cursor.fetchone()
|
||
if not existing_face:
|
||
raise HTTPException(
|
||
status_code=404,
|
||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||
)
|
||
|
||
# 构建更新语句
|
||
update_fields = []
|
||
params = []
|
||
if face_update.name is not None:
|
||
update_fields.append("name = %s")
|
||
params.append(face_update.name)
|
||
if face_update.eigenvalue is not None:
|
||
update_fields.append("eigenvalue = %s")
|
||
params.append(face_update.eigenvalue)
|
||
|
||
if not update_fields:
|
||
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
||
|
||
params.append(face_id)
|
||
update_query = f"UPDATE face SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP WHERE id = %s"
|
||
cursor.execute(update_query, params)
|
||
conn.commit()
|
||
|
||
# 查询更新后记录
|
||
select_query = "SELECT * FROM face WHERE id = %s"
|
||
cursor.execute(select_query, (face_id,))
|
||
updated_face = cursor.fetchone()
|
||
|
||
return APIResponse(
|
||
code=200,
|
||
message="人脸记录更新成功",
|
||
data=FaceResponse(**updated_face)
|
||
)
|
||
except MySQLError as e:
|
||
if conn:
|
||
conn.rollback()
|
||
raise Exception(f"更新人脸记录失败:{str(e)}") from e
|
||
finally:
|
||
db.close_connection(conn, cursor)
|
||
|
||
|
||
# ------------------------------
|
||
# 5. 删除人脸记录(不变,用自增ID删除)
|
||
# ------------------------------
|
||
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
||
async def delete_face(
|
||
face_id: int,
|
||
current_user: UserResponse = Depends(get_current_user)
|
||
):
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
check_query = "SELECT id FROM face WHERE id = %s"
|
||
cursor.execute(check_query, (face_id,))
|
||
existing_face = cursor.fetchone()
|
||
if not existing_face:
|
||
raise HTTPException(
|
||
status_code=404,
|
||
detail=f"ID为 {face_id} 的人脸记录不存在"
|
||
)
|
||
|
||
delete_query = "DELETE FROM face WHERE id = %s"
|
||
cursor.execute(delete_query, (face_id,))
|
||
conn.commit()
|
||
|
||
return APIResponse(
|
||
code=200,
|
||
message=f"ID为 {face_id} 的人脸记录删除成功",
|
||
data=None
|
||
)
|
||
except MySQLError as e:
|
||
if conn:
|
||
conn.rollback()
|
||
raise Exception(f"删除人脸记录失败:{str(e)}") from e
|
||
finally:
|
||
db.close_connection(conn, cursor)
|
||
|
||
|
||
def get_all_face_name_with_eigenvalue() -> dict:
|
||
"""
|
||
获取所有人脸的名称及其对应的特征值,组成字典返回
|
||
key: 人脸名称(name)
|
||
value: 人脸特征值(eigenvalue),若名称重复则返回平均特征值
|
||
注:过滤掉name为None的记录,避免字典key为None的情况
|
||
"""
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
# 只查询需要的字段,提高效率
|
||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||
cursor.execute(query)
|
||
faces = cursor.fetchall()
|
||
|
||
# 先收集所有名称对应的特征值列表(处理重复名称)
|
||
name_to_eigenvalues = {}
|
||
for face in faces:
|
||
name = face["name"]
|
||
eigenvalue = face["eigenvalue"]
|
||
if name in name_to_eigenvalues:
|
||
name_to_eigenvalues[name].append(eigenvalue)
|
||
else:
|
||
name_to_eigenvalues[name] = [eigenvalue]
|
||
|
||
# 构建最终字典:重复名称取平均特征值,唯一名称直接取特征值
|
||
face_dict = {}
|
||
for name, eigenvalues in name_to_eigenvalues.items():
|
||
print("调用的特征值是:" + eigenvalues)
|
||
if len(eigenvalues) > 1:
|
||
# 调用平均特征值计算方法
|
||
face_dict[name] = binary_face_feature_handler.get_average_feature(eigenvalues)
|
||
else:
|
||
face_dict[name] = eigenvalues[0]
|
||
|
||
return face_dict
|
||
|
||
except MySQLError as e:
|
||
raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e
|
||
finally:
|
||
# 确保资源释放
|
||
db.close_connection(conn, cursor)
|