326 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			326 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | from io import BytesIO | |||
|  | from pathlib import Path | |||
|  | 
 | |||
|  | from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request | |||
|  | from mysql.connector import Error as MySQLError | |||
|  | 
 | |||
|  | from ds.db import db | |||
|  | from encryption.encrypt_decorator import encrypt_response | |||
|  | from schema.face_schema import ( | |||
|  |     FaceCreateRequest, | |||
|  |     FaceResponse, | |||
|  |     FaceListResponse | |||
|  | ) | |||
|  | from schema.response_schema import APIResponse | |||
|  | from service.face_service import update_face_data | |||
|  | from util.face_util import add_binary_data | |||
|  | from service.file_service import save_source_file | |||
|  | 
 | |||
# Every endpoint in this module is mounted under /api/faces and grouped
# under the "人脸管理" tag in the generated OpenAPI docs.
router = APIRouter(
    prefix="/api/faces",
    tags=["人脸管理"],
)
|  | 
 | |||
|  | 
 | |||
# Create a face record from an uploaded image
@router.post("", response_model=APIResponse, summary="创建人脸记录")
@encrypt_response()
async def create_face(
        request: Request,
        name: str = Form(None, max_length=255, description="名称(可选)"),
        file: UploadFile = File(..., description="人脸文件(必传)")
):
    """Create a face record from an uploaded image.

    The image bytes are passed to ``add_binary_data`` first. Only when that
    succeeds is the upload persisted via ``save_source_file`` and a row
    inserted into ``face``. Afterwards ``update_face_data()`` is called.

    Args:
        request: The incoming request; a client host is required.
        name: Optional display name, validated through ``FaceCreateRequest``.
        file: The uploaded image.

    Returns:
        APIResponse wrapping a FaceResponse built from the new row.

    Raises:
        HTTPException: 400 when the client IP is unavailable or face
            detection fails; 500 when the row cannot be re-read, on MySQL
            errors, or on any other unexpected failure.
    """
    conn = None
    cursor = None
    try:
        face_create = FaceCreateRequest(name=name)
        client_ip = request.client.host if request.client else ""
        if not client_ip:
            raise HTTPException(status_code=400, detail="无法获取客户端IP")

        # Read the upload once; the same bytes feed detection and (after a
        # rewind) the file save.
        file_content = await file.read()

        # Detect before saving so a rejected image never lands on disk.
        detect_success, detect_result = add_binary_data(file_content)
        if not detect_success:
            raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
        eigenvalue = detect_result

        # save_source_file reads from the UploadFile, so rewind it first.
        await file.seek(0)
        path = save_source_file(file, "face")

        # Only hold a DB connection once there is something to insert.
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        insert_query = """
            INSERT INTO face (name, eigenvalue, address)
            VALUES (%s, %s, %s)
        """
        cursor.execute(insert_query, (face_create.name, str(eigenvalue), path))
        conn.commit()
        new_id = cursor.lastrowid

        # Re-read the inserted row to return server-generated columns.
        cursor.execute("""
            SELECT id, name, address, created_at, updated_at
            FROM face
            WHERE id = %s
        """, (new_id,))
        created_face = cursor.fetchone()
        if not created_face:
            raise HTTPException(status_code=500, detail="创建成功但无法获取记录")

        # TODO: reload the face model
        update_face_data()

        return APIResponse(
            code=200,
            message=f"人脸记录创建成功(ID: {created_face['id']})",
            data=FaceResponse(**created_face)
        )
    except HTTPException:
        # Deliberate HTTP errors (400/500 above) must keep their status code
        # instead of being rewrapped as a generic 500 below.
        raise
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
    finally:
        await file.close()
        db.close_connection(conn, cursor)
|  | 
 | |||
|  | 
 | |||
# Fetch one face record by primary key
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
@encrypt_response()
async def get_face(face_id: int):
    """Return the metadata columns of the face row whose id is ``face_id``.

    Raises:
        HTTPException: 404 when no such row exists, 500 on a MySQL error.
    """
    connection = None
    cur = None
    try:
        connection = db.get_connection()
        cur = connection.cursor(dictionary=True)

        select_by_id = """
            SELECT id, name, address, created_at, updated_at 
            FROM face 
            WHERE id = %s
        """
        cur.execute(select_by_id, (face_id,))
        row = cur.fetchone()

        # Guard clause: unknown id -> 404 (not caught by the MySQLError handler).
        if not row:
            raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")

        return APIResponse(code=200, message="查询成功", data=FaceResponse(**row))
    except MySQLError as err:
        raise HTTPException(status_code=500, detail=f"查询失败: {str(err)}") from err
    finally:
        db.close_connection(connection, cur)
|  | 
 | |||
|  | 
 | |||
# Paginated face listing with optional filters
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
@encrypt_response()
async def get_face_list(
        page: int = Query(1, ge=1),
        page_size: int = Query(10, ge=1, le=100),
        name: str = Query(None),
        has_eigenvalue: bool = Query(None)
):
    """List face records newest-first (by id), one page at a time.

    Args:
        page: 1-based page number.
        page_size: Rows per page (1-100).
        name: Optional substring, matched with SQL ``LIKE '%name%'``.
        has_eigenvalue: ``True`` keeps rows whose eigenvalue is not NULL,
            ``False`` keeps rows whose eigenvalue is NULL, ``None`` applies
            no filter.

    Returns:
        APIResponse wrapping a FaceListResponse with the filtered total and
        the requested page of FaceResponse items.

    Raises:
        HTTPException: 500 on a MySQL error.
    """
    connection = None
    cur = None
    try:
        connection = db.get_connection()
        cur = connection.cursor(dictionary=True)

        # Build the shared WHERE fragment and its bind values once; both the
        # count and the page query use it.
        conditions = []
        bind_values = []
        if name:
            conditions.append("name LIKE %s")
            bind_values.append(f"%{name}%")
        if has_eigenvalue is not None:
            if has_eigenvalue:
                conditions.append("eigenvalue IS NOT NULL")
            else:
                conditions.append("eigenvalue IS NULL")
        where_sql = (" WHERE " + " AND ".join(conditions)) if conditions else ""

        # Total number of matching rows (ignores pagination).
        cur.execute("SELECT COUNT(*) AS total FROM face" + where_sql, bind_values)
        total = cur.fetchone()["total"]

        # Current page of rows.
        base_select = """
            SELECT id, name, address, created_at, updated_at 
            FROM face
        """
        page_query = base_select + where_sql + " ORDER BY id DESC LIMIT %s OFFSET %s"
        skip = (page - 1) * page_size
        cur.execute(page_query, bind_values + [page_size, skip])
        rows = cur.fetchall()

        page_items = [FaceResponse(**row) for row in rows]
        return APIResponse(
            code=200,
            message=f"获取成功(共{total}条)",
            data=FaceListResponse(total=total, faces=page_items)
        )
    except MySQLError as err:
        raise HTTPException(status_code=500, detail=f"查询失败: {str(err)}") from err
    finally:
        db.close_connection(connection, cur)
|  | 
 | |||
|  | 
 | |||
def _remove_face_image(db_path):
    """Best-effort removal of the image file referenced by ``db_path``.

    Returns the suffix appended to the delete response message. Failures to
    unlink are printed and reported in the suffix, never raised.
    """
    if not db_path:
        return "(无关联图片)"
    image_path = Path(db_path).resolve()
    if not image_path.exists():
        return "(图片不存在)"
    try:
        image_path.unlink()
        print(f"[FaceRouter] 已删除图片:{image_path}")
    except Exception as exc:
        print(f"[FaceRouter] 删除图片失败:{str(exc)}")
        return "(图片删除失败)"
    return "(已同步删除图片)"


# Delete a face record and its stored image
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
@encrypt_response()
async def delete_face(face_id: int):
    """Delete the face row ``face_id``, then remove its image file.

    The row is deleted and committed first. The file referenced by its
    ``address`` column is then removed on a best-effort basis, and
    ``update_face_data()`` is called.

    Raises:
        HTTPException: 404 when the row does not exist, 500 on a MySQL error
            (after rolling back).
    """
    connection = None
    cur = None
    try:
        connection = db.get_connection()
        cur = connection.cursor(dictionary=True)

        cur.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
        target = cur.fetchone()
        if not target:
            raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
        stored_path = target["address"]

        cur.execute("DELETE FROM face WHERE id = %s", (face_id,))
        connection.commit()

        # The file is only touched after the DB delete has been committed.
        image_note = _remove_face_image(stored_path)

        # TODO: reload the face model
        update_face_data()

        return APIResponse(
            code=200,
            message=f"ID为 {face_id} 的记录删除成功 {image_note}",
            data=None
        )
    except MySQLError as err:
        if connection:
            connection.rollback()
        raise HTTPException(status_code=500, detail=f"删除失败: {str(err)}") from err
    finally:
        db.close_connection(connection, cur)
|  | 
 | |||
|  | 
 | |||
@router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片")
# @encrypt_response()
async def batch_import_faces(
    folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**")
):
    """Import every supported image in a server-local folder as a face record.

    Each ``.png/.jpg/.jpeg/.webp`` file directly inside ``folder_path`` is
    processed in sorted order. The file is run through ``add_binary_data``;
    only images that pass detection are copied via ``save_source_file`` and
    inserted. The file stem becomes ``name``. Each insert is committed on
    its own, and per-file failures are collected rather than aborting the
    batch.

    NOTE(review): ``folder_path`` is supplied by the client and read directly
    from the server filesystem; consider restricting it to an allowed root.

    Returns:
        APIResponse whose data holds ``success_count`` and ``fail_details``
        (a list of dicts with ``name``, ``file_path`` and ``error``).

    Raises:
        HTTPException: 400 if the folder is missing or not a directory, 500
            on connection-level MySQL errors or other unexpected failures.
    """
    conn = None
    cursor = None
    success_count = 0  # number of files imported successfully
    fail_list = []     # failure records (file name, path, reason)
    try:
        # 1. Validate the folder (is_dir() is False for missing paths too).
        folder = Path(folder_path)
        if not folder.exists() or not folder.is_dir():
            raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录")

        # 2. Supported image extensions (compared case-insensitively).
        supported_extensions = {".png", ".jpg", ".jpeg", ".webp"}

        # 3. One connection for the whole batch.
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        insert_sql = """
            INSERT INTO face (name, eigenvalue, address)
            VALUES (%s, %s, %s)
        """

        # 4. Walk the folder in a deterministic order.
        for file_path in sorted(folder.iterdir()):
            if not (file_path.is_file() and file_path.suffix.lower() in supported_extensions):
                continue
            file_name = file_path.stem  # file name without suffix is used as `name`
            try:
                file_content = file_path.read_bytes()

                # Detect before saving so rejected images are never copied.
                detect_success, detect_result = add_binary_data(file_content)
                if not detect_success:
                    fail_list.append({
                        "name": file_name,
                        "file_path": str(file_path),
                        "error": f"人脸检测失败:{detect_result}"
                    })
                    continue
                eigenvalue = detect_result

                # Wrap the bytes in an UploadFile so save_source_file can be reused.
                with BytesIO(file_content) as buffer:
                    mock_file = UploadFile(filename=file_path.name, file=buffer)
                    saved_path = save_source_file(mock_file, "face")

                cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path))
                conn.commit()  # commit each file independently
                success_count += 1

            except Exception as e:
                # Record the per-file failure and keep going with the next file.
                fail_list.append({
                    "name": file_name,
                    "file_path": str(file_path),
                    "error": str(e)
                })
                if conn:
                    conn.rollback()  # discard this file's uncommitted insert

        # 5. Reload face data only if something was actually added.
        if success_count:
            update_face_data()

        # 6. Build the response.
        return APIResponse(
            code=200,
            message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)} 条",
            data={
                "success_count": success_count,
                "fail_details": fail_list
            }
        )

    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e
    except HTTPException:
        raise  # keep business-level HTTP errors such as the 400 above
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)