Files
video_detect/router/face_router.py

326 lines
11 KiB
Python
Raw Permalink Normal View History

2025-09-30 17:17:20 +08:00
from io import BytesIO
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.face_schema import (
FaceCreateRequest,
FaceResponse,
FaceListResponse
)
from schema.response_schema import APIResponse
from service.face_service import update_face_data
from util.face_util import add_binary_data
from service.file_service import save_source_file
# Router that every endpoint in this module registers on; routes are served
# under the "/api/faces" prefix and grouped under the "人脸管理" OpenAPI tag.
router = APIRouter(
    tags=["人脸管理"],
    prefix="/api/faces",
)
# Create a face record
@router.post("", response_model=APIResponse, summary="创建人脸记录")
@encrypt_response()
async def create_face(
    request: Request,
    name: str = Form(None, max_length=255, description="名称(可选)"),
    file: UploadFile = File(..., description="人脸文件(必传)")
):
    """Create a face record from an uploaded image.

    The uploaded bytes are run through ``add_binary_data`` first. Only when
    detection succeeds is the upload persisted via ``save_source_file`` and a
    row inserted into the ``face`` table. Afterwards ``update_face_data()``
    is called (per the TODO, to reload the face model).

    Args:
        request: Incoming request; only used to require a client IP.
        name: Optional display name (max 255 characters).
        file: The uploaded face image (required).

    Returns:
        APIResponse whose ``data`` is the newly created ``FaceResponse``.

    Raises:
        HTTPException: 400 if the client IP is unavailable or face detection
            fails; 500 if the inserted row cannot be read back, on MySQL
            errors, or on any other unexpected failure.
    """
    conn = None
    cursor = None
    try:
        face_create = FaceCreateRequest(name=name)
        client_ip = request.client.host if request.client else ""
        if not client_ip:
            raise HTTPException(status_code=400, detail="无法获取客户端IP")

        # Read the upload once for detection, then rewind so that
        # save_source_file can read it again from the start.
        file_content = await file.read()
        await file.seek(0)

        # Detect before saving so a rejected image never leaves an orphan
        # file on disk.
        detect_success, detect_result = add_binary_data(file_content)
        if not detect_success:
            raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
        eigenvalue = detect_result

        path = save_source_file(file, "face")

        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        insert_query = """
        INSERT INTO face (name, eigenvalue, address)
        VALUES (%s, %s, %s)
        """
        cursor.execute(insert_query, (face_create.name, str(eigenvalue), path))
        conn.commit()

        # LAST_INSERT_ID() is per-connection, so it still refers to the row
        # inserted above on this same connection.
        cursor.execute("""
        SELECT id, name, address, created_at, updated_at
        FROM face
        WHERE id = LAST_INSERT_ID()
        """)
        created_face = cursor.fetchone()
        if not created_face:
            raise HTTPException(status_code=500, detail="创建成功但无法获取记录")

        # TODO: reload the face model
        update_face_data()
        return APIResponse(
            code=200,
            message=f"人脸记录创建成功ID: {created_face['id']}",
            data=FaceResponse(**created_face)
        )
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400s above) unchanged
        # instead of letting the generic handler re-wrap them as 500.
        raise
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
    finally:
        # Nested so the DB connection is released even if closing the
        # upload raises.
        try:
            await file.close()
        finally:
            db.close_connection(conn, cursor)
# Fetch a single face record by id
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
@encrypt_response()
async def get_face(face_id: int):
    """Return the face record whose primary key is ``face_id``.

    Raises:
        HTTPException: 404 when no row has that id; 500 on a MySQL error.
    """
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        select_sql = """
        SELECT id, name, address, created_at, updated_at
        FROM face
        WHERE id = %s
        """
        cursor.execute(select_sql, (face_id,))
        row = cursor.fetchone()
        if not row:
            raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
        payload = FaceResponse(**row)
        return APIResponse(code=200, message="查询成功", data=payload)
    except MySQLError as err:
        raise HTTPException(status_code=500, detail=f"查询失败: {str(err)}") from err
    finally:
        db.close_connection(conn, cursor)
# List face records (pagination + filters)
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
@encrypt_response()
async def get_face_list(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
    name: str = Query(None),
    has_eigenvalue: bool = Query(None)
):
    """Return one page of face records, ordered by descending id.

    Args:
        page: 1-based page number.
        page_size: Rows per page (1..100).
        name: If given, keep only rows whose ``name`` contains this text.
            The text is matched literally (``%`` / ``_`` are not wildcards).
        has_eigenvalue: If True, keep rows with a non-NULL ``eigenvalue``;
            if False, rows where it is NULL; if omitted, no filter.

    Returns:
        APIResponse whose ``data`` is a ``FaceListResponse`` carrying the
        total matching count and the requested page of ``FaceResponse``s.

    Raises:
        HTTPException: 500 on a MySQL error.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        conditions = []
        params = []
        if name:
            # Escape LIKE metacharacters so user text is matched literally.
            # '!' is declared as the escape char explicitly so behavior does
            # not depend on the server's NO_BACKSLASH_ESCAPES sql_mode.
            escaped = name.translate(str.maketrans({"!": "!!", "%": "!%", "_": "!_"}))
            conditions.append("name LIKE %s ESCAPE '!'")
            params.append(f"%{escaped}%")
        if has_eigenvalue is not None:
            conditions.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
        # Built once and shared by the count and the page query so both
        # always apply the same filters.
        where_sql = " WHERE " + " AND ".join(conditions) if conditions else ""

        # Total matching rows
        cursor.execute("SELECT COUNT(*) AS total FROM face" + where_sql, params)
        total = cursor.fetchone()["total"]

        # Requested page
        offset = (page - 1) * page_size
        list_query = (
            "SELECT id, name, address, created_at, updated_at FROM face"
            + where_sql
            + " ORDER BY id DESC LIMIT %s OFFSET %s"
        )
        cursor.execute(list_query, [*params, page_size, offset])
        face_list = cursor.fetchall()
        return APIResponse(
            code=200,
            message=f"获取成功(共{total}条)",
            data=FaceListResponse(
                total=total,
                faces=[FaceResponse(**face) for face in face_list]
            )
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
def _remove_face_image(db_path):
    """Delete the image file stored at ``db_path`` (resolved from the CWD).

    Returns the status suffix that ``delete_face`` appends to its response
    message. Unlink failures are printed and reported, not raised.
    """
    if not db_path:
        return "(无关联图片)"
    abs_path = Path(db_path).resolve()
    if not abs_path.exists():
        return "(图片不存在)"
    try:
        abs_path.unlink()
        print(f"[FaceRouter] 已删除图片:{abs_path}")
    except Exception as exc:
        # Best effort: the DB row is already deleted, so only report it.
        print(f"[FaceRouter] 删除图片失败:{str(exc)}")
        return "(图片删除失败)"
    return "(已同步删除图片)"


# Delete a face record
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
@encrypt_response()
async def delete_face(face_id: int):
    """Delete the face record ``face_id`` and then its stored image file.

    The row is deleted and committed first; the image is removed afterwards
    on a best-effort basis, and the outcome is appended to the message.

    Raises:
        HTTPException: 404 when no row has that id; 500 on a MySQL error.
    """
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
        target = cursor.fetchone()
        if not target:
            raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
        image_path = target["address"]

        cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
        conn.commit()

        extra_msg = _remove_face_image(image_path)
        # TODO: reload the face model
        update_face_data()
        return APIResponse(
            code=200,
            message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
            data=None
        )
    except MySQLError as err:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"删除失败: {str(err)}") from err
    finally:
        db.close_connection(conn, cursor)
@router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片")
# NOTE(review): response encryption is disabled here unlike the other
# endpoints in this router — confirm this is intentional.
# @encrypt_response()
async def batch_import_faces(
    folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**")
):
    """Import every supported image in a server-side folder as a face record.

    Only regular files directly inside ``folder_path`` (not subfolders) with
    a .png/.jpg/.jpeg/.webp suffix are considered. They are processed in
    sorted path order. Each file's stem becomes the record ``name``. Files are
    handled independently: each successful insert is committed on its own,
    and any per-file failure (no face detected, save error, DB error) is
    recorded in ``fail_details`` and skipped. ``update_face_data()`` is
    called once at the end.

    SECURITY NOTE(review): ``folder_path`` comes from the client, so any
    directory readable by the server process can be imported — confirm this
    endpoint is restricted to trusted callers.

    Returns:
        APIResponse whose ``data`` is a dict with ``success_count`` and
        ``fail_details`` (list of dicts with ``name``, ``file_path``, ``error``).

    Raises:
        HTTPException: 400 if ``folder_path`` is not an existing directory;
            500 on MySQL or other errors raised outside per-file handling.
    """
    conn = None
    cursor = None
    success_count = 0  # number of files imported successfully
    fail_list = []  # failures: file name, path and error reason
    try:
        # 1. Validate the folder
        folder = Path(folder_path)
        if not folder.exists() or not folder.is_dir():
            raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录")
        # 2. Supported image formats
        supported_extensions = {".png", ".jpg", ".jpeg", ".webp"}
        insert_sql = """
        INSERT INTO face (name, eigenvalue, address)
        VALUES (%s, %s, %s)
        """
        # 3. Database connection
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # 4. Walk the folder; sorted() makes the processing (and hence id)
        #    order deterministic, since iterdir() order is filesystem-defined.
        for file_path in sorted(folder.iterdir()):
            if not (file_path.is_file() and file_path.suffix.lower() in supported_extensions):
                continue
            file_name = file_path.stem  # file name without suffix, used as `name`
            try:
                file_content = file_path.read_bytes()
                # 4.1 Detect before saving so images without a face are never
                #     written to the upload directory.
                detect_success, detect_result = add_binary_data(file_content)
                if not detect_success:
                    fail_list.append({
                        "name": file_name,
                        "file_path": str(file_path),
                        "error": f"人脸检测失败:{detect_result}"
                    })
                    continue  # skip this file, move on to the next
                eigenvalue = detect_result
                # 4.2 Wrap the bytes in an UploadFile so save_source_file can
                #     be reused unchanged.
                mock_file = UploadFile(
                    filename=file_path.name,
                    file=BytesIO(file_content)
                )
                saved_path = save_source_file(mock_file, "face")
                # 4.3 Insert and commit this file on its own
                cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path))
                conn.commit()
                success_count += 1
            except Exception as e:
                # Record the per-file failure and keep processing the rest.
                fail_list.append({
                    "name": file_name,
                    "file_path": str(file_path),
                    "error": str(e)
                })
                if conn:
                    conn.rollback()  # discard any uncommitted insert for this file
        # 5. Reload face data so newly inserted rows take effect
        update_face_data()
        # 6. Build the response
        return APIResponse(
            code=200,
            message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)}",
            data={
                "success_count": success_count,
                "fail_details": fail_list
            }
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e
    except HTTPException:
        raise  # re-raise deliberate HTTP errors (e.g. the 400 above) unchanged
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)