from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form from mysql.connector import Error as MySQLError from ds.db import db from schema.face_schema import FaceCreateRequest, FaceUpdateRequest, FaceResponse from schema.response_schema import APIResponse from middle.auth_middleware import get_current_user from schema.user_schema import UserResponse from ocr.feature_extraction import BinaryFaceFeatureHandler router = APIRouter( prefix="/faces", tags=["人脸管理"] ) # 创建 BinaryFaceFeatureHandler 的实例 binary_face_feature_handler = BinaryFaceFeatureHandler() # ------------------------------ # 1. 创建人脸记录(核心修正:ID 数据库自增,前端无需传) # ------------------------------ @router.post("", response_model=APIResponse, summary="创建人脸记录(传名称+文件,ID自增)") async def create_face( # 前端仅需传:name(可选,Form格式)、file(必传,文件) name: str = Form(None, max_length=255, description="名称(可选)"), file: UploadFile = File(..., description="人脸文件(必传,暂不处理内容)") ): """ 创建人脸记录: - 需登录认证 - 前端传参:multipart/form-data 表单(name 可选,file 必传) - ID 由数据库自动生成,无需前端传入 - 暂不处理文件内容,eigenvalue 设为 None """ conn = None cursor = None try: # 1. 用模型校验 name(仅校验长度,无需ID) face_create = FaceCreateRequest(name=name) conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 把文件转为二进制数组 file_content = await file.read() # 调用人脸识别得到特征值 # 2. 插入数据库:无需传 ID(自增),只传 name 和 eigenvalue(None) insert_query = """ INSERT INTO face (name, eigenvalue) VALUES (%s, %s) """ cursor.execute(insert_query, (face_create.name, None)) conn.commit() # 3. 获取数据库自动生成的 ID(关键:用 LAST_INSERT_ID() 查刚插入的记录) select_new_query = "SELECT * FROM face WHERE id = LAST_INSERT_ID()" cursor.execute(select_new_query) created_face = cursor.fetchone() return APIResponse( code=201, message=f"人脸记录创建成功(ID:{created_face['id']},文件名:{file.filename})", data=FaceResponse(**created_face) ) except MySQLError as e: if conn: conn.rollback() raise Exception(f"创建人脸记录失败:{str(e)}") from e finally: await file.close() # 关闭文件流 db.close_connection(conn, cursor) # ------------------------------ # 2. 获取单个人脸记录(不变,用自增ID查询) # ------------------------------ @router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录") async def get_face( face_id: int, # 这里的 ID 是数据库自增的,前端从创建响应中获取 current_user: UserResponse = Depends(get_current_user) ): conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) query = "SELECT * FROM face WHERE id = %s" cursor.execute(query, (face_id,)) face = cursor.fetchone() if not face: raise HTTPException( status_code=404, detail=f"ID为 {face_id} 的人脸记录不存在" ) return APIResponse( code=200, message="人脸记录查询成功", data=FaceResponse(**face) ) except MySQLError as e: raise Exception(f"查询人脸记录失败:{str(e)}") from e finally: db.close_connection(conn, cursor) # 后续 3.获取所有、4.更新、5.删除 接口(不变,仅用数据库自增的 ID 操作,无需修改) # ------------------------------ # 3. 获取所有人脸记录(不变) # ------------------------------ @router.get("", response_model=APIResponse, summary="获取所有人脸记录") async def get_all_faces( current_user: UserResponse = Depends(get_current_user) ): conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) query = "SELECT * FROM face ORDER BY id" # 按自增ID排序 cursor.execute(query) faces = cursor.fetchall() return APIResponse( code=200, message="所有人脸记录查询成功", data=[FaceResponse(**face) for face in faces] ) except MySQLError as e: raise Exception(f"查询所有人脸记录失败:{str(e)}") from e finally: db.close_connection(conn, cursor) # ------------------------------ # 4. 更新人脸记录(不变,用自增ID更新) # ------------------------------ @router.put("/{face_id}", response_model=APIResponse, summary="更新人脸记录") async def update_face( face_id: int, face_update: FaceUpdateRequest, current_user: UserResponse = Depends(get_current_user) ): conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 检查记录是否存在 check_query = "SELECT id FROM face WHERE id = %s" cursor.execute(check_query, (face_id,)) existing_face = cursor.fetchone() if not existing_face: raise HTTPException( status_code=404, detail=f"ID为 {face_id} 的人脸记录不存在" ) # 构建更新语句 update_fields = [] params = [] if face_update.name is not None: update_fields.append("name = %s") params.append(face_update.name) if face_update.eigenvalue is not None: update_fields.append("eigenvalue = %s") params.append(face_update.eigenvalue) if not update_fields: raise HTTPException(status_code=400, detail="至少需提供一个更新字段") params.append(face_id) update_query = f"UPDATE face SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP WHERE id = %s" cursor.execute(update_query, params) conn.commit() # 查询更新后记录 select_query = "SELECT * FROM face WHERE id = %s" cursor.execute(select_query, (face_id,)) updated_face = cursor.fetchone() return APIResponse( code=200, message="人脸记录更新成功", data=FaceResponse(**updated_face) ) except MySQLError as e: if conn: conn.rollback() raise Exception(f"更新人脸记录失败:{str(e)}") from e finally: db.close_connection(conn, cursor) # ------------------------------ # 5. 删除人脸记录(不变,用自增ID删除) # ------------------------------ @router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录") async def delete_face( face_id: int, current_user: UserResponse = Depends(get_current_user) ): conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) check_query = "SELECT id FROM face WHERE id = %s" cursor.execute(check_query, (face_id,)) existing_face = cursor.fetchone() if not existing_face: raise HTTPException( status_code=404, detail=f"ID为 {face_id} 的人脸记录不存在" ) delete_query = "DELETE FROM face WHERE id = %s" cursor.execute(delete_query, (face_id,)) conn.commit() return APIResponse( code=200, message=f"ID为 {face_id} 的人脸记录删除成功", data=None ) except MySQLError as e: if conn: conn.rollback() raise Exception(f"删除人脸记录失败:{str(e)}") from e finally: db.close_connection(conn, cursor) def get_all_face_name_with_eigenvalue() -> dict: """ 获取所有人脸的名称及其对应的特征值,组成字典返回 key: 人脸名称(name) value: 人脸特征值(eigenvalue),若名称重复则返回平均特征值 注:过滤掉name为None的记录,避免字典key为None的情况 """ conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 只查询需要的字段,提高效率 query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL" cursor.execute(query) faces = cursor.fetchall() # 先收集所有名称对应的特征值列表(处理重复名称) name_to_eigenvalues = {} for face in faces: name = face["name"] eigenvalue = face["eigenvalue"] if name in name_to_eigenvalues: name_to_eigenvalues[name].append(eigenvalue) else: name_to_eigenvalues[name] = [eigenvalue] # 构建最终字典:重复名称取平均特征值,唯一名称直接取特征值 face_dict = {} for name, eigenvalues in name_to_eigenvalues.items(): print("调用的特征值是:" + eigenvalues) if len(eigenvalues) > 1: # 调用平均特征值计算方法 face_dict[name] = binary_face_feature_handler.get_average_feature(eigenvalues) else: face_dict[name] = eigenvalues[0] return face_dict except MySQLError as e: raise Exception(f"获取人脸名称与特征值失败:{str(e)}") from e finally: # 确保资源释放 db.close_connection(conn, cursor)