326 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			326 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from io import BytesIO
 | ||
| from pathlib import Path
 | ||
| 
 | ||
| from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
 | ||
| from mysql.connector import Error as MySQLError
 | ||
| 
 | ||
| from ds.db import db
 | ||
| from encryption.encrypt_decorator import encrypt_response
 | ||
| from schema.face_schema import (
 | ||
|     FaceCreateRequest,
 | ||
|     FaceResponse,
 | ||
|     FaceListResponse
 | ||
| )
 | ||
| from schema.response_schema import APIResponse
 | ||
| from service.face_service import update_face_data
 | ||
| from util.face_util import add_binary_data
 | ||
| from service.file_service import save_source_file
 | ||
| 
 | ||
# Router for the face-management endpoints; every route in this module is
# registered on it and therefore mounted under the /api/faces prefix.
router = APIRouter(
    prefix="/api/faces",
    tags=["人脸管理"],
)
 | ||
| 
 | ||
| 
 | ||
# Create a face record
@router.post("", response_model=APIResponse, summary="创建人脸记录")
@encrypt_response()
async def create_face(
        request: Request,
        name: str = Form(None, max_length=255, description="名称(可选)"),
        file: UploadFile = File(..., description="人脸文件(必传)")
):
    """Create a face record from an uploaded image.

    The upload is run through ``add_binary_data`` for face feature
    extraction. Only if that succeeds is the file persisted through
    ``save_source_file`` and a row ``(name, eigenvalue, address)`` inserted
    into the ``face`` table. After that, ``update_face_data()`` is called.

    Args:
        request: Incoming request. It is only used to reject callers whose
            client IP cannot be determined.
        name: Optional display name (at most 255 characters).
        file: Uploaded face image (required).

    Returns:
        APIResponse whose ``data`` is the newly created record as a
        FaceResponse.

    Raises:
        HTTPException: 400 if the client IP is unavailable or no face could
            be detected; 500 on a MySQL error, if the new row cannot be read
            back, or on any other unexpected error.
    """
    conn = None
    cursor = None
    try:
        face_create = FaceCreateRequest(name=name)
        client_ip = request.client.host if request.client else ""
        if not client_ip:
            raise HTTPException(status_code=400, detail="无法获取客户端IP")

        # Read the upload once; these bytes feed feature extraction.
        file_content = await file.read()

        # Extract face features BEFORE persisting anything, so an image without
        # a detectable face is rejected without leaving an orphaned file behind.
        detect_success, detect_result = add_binary_data(file_content)
        if not detect_success:
            raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
        eigenvalue = detect_result

        # Acquire the connection before saving, so a connection failure does not
        # leave a saved file without a DB row.
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        # Rewind the upload (it was fully read above) so the whole file is saved.
        await file.seek(0)
        path = save_source_file(file, "face")

        # Insert the record
        insert_query = """
            INSERT INTO face (name, eigenvalue, address)
            VALUES (%s, %s, %s)
        """
        cursor.execute(insert_query, (face_create.name, str(eigenvalue), path))
        conn.commit()

        # Read the new row back; LAST_INSERT_ID() is scoped to this connection.
        cursor.execute("""
            SELECT id, name, address, created_at, updated_at 
            FROM face 
            WHERE id = LAST_INSERT_ID()
        """)
        created_face = cursor.fetchone()
        if not created_face:
            raise HTTPException(status_code=500, detail="创建成功但无法获取记录")

        # Refresh face data after the insert (original note: "TODO reload the face model").
        update_face_data()

        return APIResponse(
            code=200,
            message=f"人脸记录创建成功(ID: {created_face['id']})",
            data=FaceResponse(**created_face)
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
    except HTTPException:
        # Deliberate HTTP errors raised above (e.g. 400s) must keep their status
        # code instead of being re-wrapped as a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
    finally:
        await file.close()
        db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
# Fetch one face record
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
@encrypt_response()
async def get_face(face_id: int):
    """Look up a single face record by its primary key.

    Only id, name, address and the two timestamps are selected; the stored
    eigenvalue is not part of the response.

    Args:
        face_id: Primary key of the ``face`` row.

    Returns:
        APIResponse whose ``data`` is the matching record as a FaceResponse.

    Raises:
        HTTPException: 404 if no row has ``face_id``; 500 on a MySQL error.
    """
    connection = None
    cur = None
    try:
        connection = db.get_connection()
        cur = connection.cursor(dictionary=True)

        select_sql = """
            SELECT id, name, address, created_at, updated_at 
            FROM face 
            WHERE id = %s
        """
        cur.execute(select_sql, (face_id,))
        row = cur.fetchone()
        if not row:
            raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")

        record = FaceResponse(**row)
        return APIResponse(code=200, message="查询成功", data=record)
    except MySQLError as err:
        raise HTTPException(status_code=500, detail=f"查询失败: {str(err)}") from err
    finally:
        db.close_connection(connection, cur)
 | ||
| 
 | ||
| 
 | ||
# Escape character declared in the LIKE ... ESCAPE clause of the name filter.
# '!' is used instead of the default backslash so the behaviour does not depend
# on the server's NO_BACKSLASH_ESCAPES sql_mode.
_LIKE_ESCAPE_CHAR = "!"


def _escape_like_pattern(value, escape_char=_LIKE_ESCAPE_CHAR):
    """Return ``value`` with LIKE metacharacters escaped so it matches literally.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    each prefixed with it. The SQL must declare the same character via
    ``ESCAPE``.
    """
    return (
        value.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


# Get the face list
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
@encrypt_response()
async def get_face_list(
        page: int = Query(1, ge=1),
        page_size: int = Query(10, ge=1, le=100),
        name: str = Query(None),
        has_eigenvalue: bool = Query(None)
):
    """List face records with pagination and optional filters.

    Args:
        page: 1-based page number.
        page_size: Rows per page (1-100).
        name: If non-empty, keep rows whose name contains this text. The text
            is matched literally, so ``%`` and ``_`` are not wildcards.
        has_eigenvalue: True keeps rows with a non-NULL eigenvalue, False
            keeps rows with a NULL one; omitted means no filter.

    Returns:
        APIResponse wrapping a FaceListResponse with the total number of
        matching rows and the requested page, ordered by id descending.

    Raises:
        HTTPException: 500 on a MySQL error.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        where_clause = []
        params = []
        if name:
            # Substring search; user-supplied wildcards are escaped so that,
            # e.g., "a_b" does not also match "axb".
            where_clause.append(f"name LIKE %s ESCAPE '{_LIKE_ESCAPE_CHAR}'")
            params.append(f"%{_escape_like_pattern(name)}%")
        if has_eigenvalue is not None:
            where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
        where_sql = (" WHERE " + " AND ".join(where_clause)) if where_clause else ""

        # Total number of matching rows
        count_query = "SELECT COUNT(*) AS total FROM face" + where_sql
        cursor.execute(count_query, params)
        total = cursor.fetchone()["total"]

        # Requested page; filter params are copied rather than mutated in place
        offset = (page - 1) * page_size
        list_query = """
            SELECT id, name, address, created_at, updated_at 
            FROM face
        """ + where_sql + " ORDER BY id DESC LIMIT %s OFFSET %s"
        cursor.execute(list_query, [*params, page_size, offset])
        face_list = cursor.fetchall()

        return APIResponse(
            code=200,
            message=f"获取成功(共{total}条)",
            data=FaceListResponse(
                total=total,
                faces=[FaceResponse(**face) for face in face_list]
            )
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
def _delete_face_image_file(db_path):
    """Best-effort removal of the image file referenced by a ``face.address`` value.

    A relative ``db_path`` is resolved against the current working directory.
    Failures while unlinking are printed, not raised: the caller has already
    committed the row deletion by the time this runs.

    Returns:
        The status suffix appended to the delete response message.
    """
    if not db_path:
        return "(无关联图片)"

    image_path = Path(db_path).resolve()
    if not image_path.exists():
        return "(图片不存在)"

    try:
        image_path.unlink()
        print(f"[FaceRouter] 已删除图片:{image_path}")
        return "(已同步删除图片)"
    except Exception as exc:
        print(f"[FaceRouter] 删除图片失败:{str(exc)}")
        return "(图片删除失败)"


# Delete a face record
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
@encrypt_response()
async def delete_face(face_id: int):
    """Delete a face record and, best-effort, its stored image file.

    The row deletion is committed first. Image removal then runs, and its
    outcome only changes the response message suffix. Finally,
    ``update_face_data()`` is called.

    Args:
        face_id: Primary key of the ``face`` row to delete.

    Returns:
        APIResponse with ``data=None`` and a message describing the image outcome.

    Raises:
        HTTPException: 404 if no row has ``face_id``; 500 on a MySQL error
            (the transaction is rolled back).
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
        target = cursor.fetchone()
        if not target:
            raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
        image_db_path = target["address"]

        cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
        conn.commit()

        # Remove the image only after the row deletion is committed.
        image_note = _delete_face_image_file(image_db_path)

        # Refresh face data after the delete (original note: "TODO reload the face model").
        update_face_data()

        done_msg = f"ID为 {face_id} 的记录删除成功 {image_note}"
        return APIResponse(code=200, message=done_msg, data=None)
    except MySQLError as err:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"删除失败: {str(err)}") from err
    finally:
        db.close_connection(conn, cursor)
 | ||
| 
 | ||
| 
 | ||
@router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片")
# @encrypt_response()
async def batch_import_faces(
    folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**")
):
    """Import every supported image in a server-side folder as a face record.

    Each regular file with a .png/.jpg/.jpeg/.webp extension (case-insensitive)
    is handled as follows:
      - it is read and run through face feature extraction;
      - if a face is found, the file is saved via ``save_source_file`` and
        inserted with its file stem as ``name``;
      - each insert is committed separately.

    Per-file failures are collected and processing continues with the next
    file. ``update_face_data()`` is called once after the loop.

    NOTE(security): ``folder_path`` comes straight from the request and is
    read from the server filesystem without any allow-list. Restrict it to an
    approved base directory before exposing this endpoint. Unlike the other
    routes here, ``encrypt_response`` is commented out for this one.

    Args:
        folder_path: Path of a directory on the server containing face images.

    Returns:
        APIResponse whose ``data`` holds ``success_count`` and ``fail_details``
        (a list of ``{"name", "file_path", "error"}`` dicts).

    Raises:
        HTTPException: 400 if the folder does not exist or is not a
            directory; 500 on a MySQL error outside per-file handling or any
            other unexpected error.
    """
    conn = None
    cursor = None
    success_count = 0  # number of images imported successfully
    fail_list = []     # per-file failures: name, file path and error reason
    try:
        # 1. Validate the folder
        folder = Path(folder_path)
        if not folder.exists() or not folder.is_dir():
            raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录")

        # 2. Supported image extensions (compared lower-cased)
        supported_extensions = {".png", ".jpg", ".jpeg", ".webp"}

        # 3. Open the DB connection
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        insert_sql = """
            INSERT INTO face (name, eigenvalue, address)
            VALUES (%s, %s, %s)
        """

        # 4. Walk the folder. iterdir() yields entries in arbitrary order;
        #    sorting makes the import order and the fail list deterministic.
        for file_path in sorted(folder.iterdir()):
            if not (file_path.is_file() and file_path.suffix.lower() in supported_extensions):
                continue
            file_name = file_path.stem  # file name without extension becomes `name`
            try:
                # 4.1 Read the raw image bytes
                with open(file_path, "rb") as f:
                    file_content = f.read()

                # 4.2 Extract face features BEFORE saving, so images without a
                #     detectable face are not copied into storage as orphans.
                detect_success, detect_result = add_binary_data(file_content)
                if not detect_success:
                    fail_list.append({
                        "name": file_name,
                        "file_path": str(file_path),
                        "error": f"人脸检测失败:{detect_result}"
                    })
                    continue  # skip this file, move on to the next one
                eigenvalue = detect_result

                # 4.3 Wrap the bytes in an UploadFile so save_source_file can be reused
                mock_file = UploadFile(
                    filename=file_path.name,
                    file=BytesIO(file_content)
                )
                saved_path = save_source_file(mock_file, "face")

                # 4.4 Insert and commit this file's row
                cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path))
                conn.commit()

                success_count += 1

            except Exception as e:
                # Record the per-file failure, undo its pending insert and keep going
                fail_list.append({
                    "name": file_name,
                    "file_path": str(file_path),
                    "error": str(e)
                })
                if conn:
                    conn.rollback()

        # 5. Refresh face data so the newly imported records take effect
        update_face_data()

        # 6. Build the response
        return APIResponse(
            code=200,
            message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)} 条",
            data={
                "success_count": success_count,
                "fail_details": fail_list
            }
        )

    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e
    except HTTPException:
        raise  # keep deliberate HTTP errors (e.g. the 400 above) unchanged
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)