Files
video_detect/router/face_router.py
2025-09-30 17:17:20 +08:00

326 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from io import BytesIO
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.face_schema import (
FaceCreateRequest,
FaceResponse,
FaceListResponse
)
from schema.response_schema import APIResponse
from service.face_service import update_face_data
from util.face_util import add_binary_data
from service.file_service import save_source_file
# Router that groups every face-management endpoint under the /api/faces prefix.
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
# Create a face record.
#
# Accepts a multipart form with an optional `name` and a mandatory image
# `file`, extracts the face feature from the image, stores the image through
# `save_source_file`, inserts a row into the `face` table and returns the
# freshly created row wrapped in an APIResponse.
#
# Raises HTTPException:
#   400 - the client IP cannot be determined, or face detection failed
#   500 - database error, or any other unexpected failure
@router.post("", response_model=APIResponse, summary="创建人脸记录")
@encrypt_response()
async def create_face(
    request: Request,
    name: str = Form(None, max_length=255, description="名称(可选)"),
    file: UploadFile = File(..., description="人脸文件(必传)")
):
    conn = None
    cursor = None
    try:
        face_create = FaceCreateRequest(name=name)

        client_ip = request.client.host if request.client else ""
        if not client_ip:
            raise HTTPException(status_code=400, detail="无法获取客户端IP")

        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        # Read the whole upload once; the bytes feed the face detector.
        file_content = await file.read()

        # Run detection BEFORE persisting the image so that a rejected upload
        # never leaves an orphaned file behind in storage.
        detect_success, detect_result = add_binary_data(file_content)
        if not detect_success:
            raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
        eigenvalue = detect_result

        # Rewind the upload (it was consumed by read() above), then store it.
        await file.seek(0)
        path = save_source_file(file, "face")

        # Insert the new row.
        insert_query = """
            INSERT INTO face (name, eigenvalue, address)
            VALUES (%s, %s, %s)
        """
        cursor.execute(insert_query, (face_create.name, str(eigenvalue), path))
        conn.commit()

        # Read the row back on the same connection so LAST_INSERT_ID() refers
        # to the insert performed just above.
        cursor.execute("""
            SELECT id, name, address, created_at, updated_at
            FROM face
            WHERE id = LAST_INSERT_ID()
        """)
        created_face = cursor.fetchone()
        if not created_face:
            raise HTTPException(status_code=500, detail="创建成功但无法获取记录")

        # TODO: reload the face model.
        # NOTE(review): update_face_data() presumably refreshes the in-memory
        # face data after the insert - confirm against service.face_service.
        update_face_data()

        return APIResponse(
            code=200,
            message=f"人脸记录创建成功ID: {created_face['id']}",
            data=FaceResponse(**created_face)
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
    except HTTPException:
        # Deliberate 4xx/5xx responses raised above must reach the client
        # unchanged instead of being rewrapped as a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
    finally:
        # Release the DB resources even if closing the upload fails.
        try:
            await file.close()
        finally:
            db.close_connection(conn, cursor)
# Fetch a single face record by its primary key.
#
# Returns an APIResponse wrapping the matching row; raises HTTPException 404
# when no row has the given id and HTTPException 500 on a MySQL error.
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
@encrypt_response()
async def get_face(face_id: int):
    connection = None
    db_cursor = None
    select_sql = """
        SELECT id, name, address, created_at, updated_at
        FROM face
        WHERE id = %s
    """
    try:
        connection = db.get_connection()
        db_cursor = connection.cursor(dictionary=True)
        db_cursor.execute(select_sql, (face_id,))
        row = db_cursor.fetchone()
        if row:
            return APIResponse(
                code=200,
                message="查询成功",
                data=FaceResponse(**row)
            )
        # Falling through means the lookup returned nothing.
        raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
    except MySQLError as db_error:
        raise HTTPException(status_code=500, detail=f"查询失败: {db_error!s}") from db_error
    finally:
        db.close_connection(connection, db_cursor)
# List face records with pagination and optional filters.
#
# `name` filters with a LIKE substring match; `has_eigenvalue` (when given)
# restricts the result to rows whose eigenvalue is / is not NULL. The response
# carries the total number of matching rows plus the requested page, newest
# id first. A MySQL error is reported as HTTPException 500.
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
@encrypt_response()
async def get_face_list(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
    name: str = Query(None),
    has_eigenvalue: bool = Query(None)
):
    connection = None
    db_cursor = None
    try:
        connection = db.get_connection()
        db_cursor = connection.cursor(dictionary=True)

        # Collect the SQL conditions and their bound values side by side.
        conditions = []
        filter_args = []
        if name:
            conditions.append("name LIKE %s")
            filter_args.append(f"%{name}%")
        if has_eigenvalue is not None:
            if has_eigenvalue:
                conditions.append("eigenvalue IS NOT NULL")
            else:
                conditions.append("eigenvalue IS NULL")

        # The same WHERE fragment is shared by the count and the page query.
        where_sql = ""
        if conditions:
            where_sql = " WHERE " + " AND ".join(conditions)

        # Total number of matching rows.
        count_sql = "SELECT COUNT(*) AS total FROM face" + where_sql
        db_cursor.execute(count_sql, filter_args)
        total = db_cursor.fetchone()["total"]

        # Requested page of rows.
        skip = (page - 1) * page_size
        page_sql = """
            SELECT id, name, address, created_at, updated_at
            FROM face
        """
        page_sql = page_sql + where_sql + " ORDER BY id DESC LIMIT %s OFFSET %s"
        db_cursor.execute(page_sql, filter_args + [page_size, skip])
        rows = db_cursor.fetchall()

        faces = [FaceResponse(**row) for row in rows]
        return APIResponse(
            code=200,
            message=f"获取成功(共{total}条)",
            data=FaceListResponse(total=total, faces=faces)
        )
    except MySQLError as db_error:
        raise HTTPException(status_code=500, detail=f"查询失败: {db_error!s}") from db_error
    finally:
        db.close_connection(connection, db_cursor)
# Delete a face record and, best effort, the image file it points to.
#
# Raises HTTPException 404 when the id does not exist and HTTPException 500
# on a MySQL error (after rolling back). A failure to remove the image file
# does not fail the request; it is only reflected in the response message.
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
@encrypt_response()
async def delete_face(face_id: int):
    connection = None
    db_cursor = None
    try:
        connection = db.get_connection()
        db_cursor = connection.cursor(dictionary=True)

        db_cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
        existing = db_cursor.fetchone()
        if not existing:
            raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
        stored_path = existing["address"]

        db_cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
        connection.commit()

        # Remove the image file; the outcome only affects the message suffix.
        image_path = Path(stored_path).resolve() if stored_path else None
        if image_path is None:
            extra_msg = "(无关联图片)"
        elif not image_path.exists():
            extra_msg = "(图片不存在)"
        else:
            try:
                image_path.unlink()
                print(f"[FaceRouter] 已删除图片:{image_path}")
                extra_msg = "(已同步删除图片)"
            except Exception as unlink_error:
                print(f"[FaceRouter] 删除图片失败:{str(unlink_error)}")
                extra_msg = "(图片删除失败)"

        # TODO: reload the face model.
        update_face_data()

        return APIResponse(
            code=200,
            message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
            data=None
        )
    except MySQLError as db_error:
        if connection:
            connection.rollback()
        raise HTTPException(status_code=500, detail=f"删除失败: {db_error!s}") from db_error
    finally:
        db.close_connection(connection, db_cursor)
# Bulk-import every supported image found directly inside a server-side folder.
#
# For each .png/.jpg/.jpeg/.webp file the face feature is extracted, the image
# is stored through `save_source_file`, and a row is inserted into `face` with
# the file stem as `name`. Each file is committed on its own, so one bad file
# does not abort the batch; failures are collected and returned to the caller.
#
# Raises HTTPException:
#   400 - `folder_path` does not exist or is not a directory
#   500 - database error outside per-file handling, or any unexpected failure
#
# NOTE(review): `folder_path` is a client-supplied path on the server's own
# filesystem. Confirm this endpoint is restricted to trusted callers, or
# constrain the path to an allow-listed base directory.
@router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片")
# NOTE(review): response encryption is disabled here, unlike the other
# endpoints in this router - confirm this is intentional.
# @encrypt_response()
async def batch_import_faces(
    folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**")
):
    conn = None
    cursor = None
    success_count = 0  # number of files imported successfully
    fail_list = []  # one entry per failed file: name, path, error reason
    try:
        # 1. Validate the folder (is_dir() is already False for a missing path).
        folder = Path(folder_path)
        if not folder.is_dir():
            raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录")

        # 2. Supported image formats and the (loop-invariant) insert statement.
        supported_extensions = {".png", ".jpg", ".jpeg", ".webp"}
        insert_sql = """
            INSERT INTO face (name, eigenvalue, address)
            VALUES (%s, %s, %s)
        """

        # 3. Open the database connection.
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        # 4. Walk the direct children of the folder.
        for file_path in folder.iterdir():
            if not file_path.is_file() or file_path.suffix.lower() not in supported_extensions:
                continue

            file_name = file_path.stem  # file name without suffix, used as `name`
            try:
                # 4.1 Read the raw image bytes.
                file_content = file_path.read_bytes()

                # 4.2 Extract the face feature BEFORE copying the image into
                #     storage, so undetectable images leave no orphaned copy.
                detect_success, detect_result = add_binary_data(file_content)
                if not detect_success:
                    fail_list.append({
                        "name": file_name,
                        "file_path": str(file_path),
                        "error": f"人脸检测失败:{detect_result}"
                    })
                    continue  # skip this file and move on to the next one
                eigenvalue = detect_result

                # 4.3 Wrap the bytes in an UploadFile so `save_source_file`
                #     can be reused, then store the image.
                mock_file = UploadFile(
                    filename=file_path.name,
                    file=BytesIO(file_content)
                )
                saved_path = save_source_file(mock_file, "face")

                # 4.4 Insert and commit this file on its own.
                cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path))
                conn.commit()
                success_count += 1
            except Exception as e:
                # Record the per-file failure and keep processing the rest.
                fail_list.append({
                    "name": file_name,
                    "file_path": str(file_path),
                    "error": str(e)
                })
                if conn:
                    conn.rollback()  # discard the failed file's pending insert

        # 5. Reload the face data so the new rows take effect.
        update_face_data()

        # 6. Build the summary response.
        return APIResponse(
            code=200,
            message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)}",
            data={
                "success_count": success_count,
                "fail_details": fail_list
            }
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e
    except HTTPException:
        raise  # pass through the 400 (and similar) raised by the logic above
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)