内容安全审核

This commit is contained in:
2025-09-30 17:17:20 +08:00
commit cc6e66bbf8
523 changed files with 4853 additions and 0 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,112 @@
from fastapi import APIRouter, Query, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.device_action_schema import (
DeviceActionResponse,
DeviceActionListResponse
)
from schema.response_schema import APIResponse
# 路由配置
router = APIRouter(
prefix="/api/device/actions",
tags=["设备操作记录"]
)
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
@encrypt_response()
async def get_device_action_list(
page: int = Query(1, ge=1, description="页码、默认1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
client_ip: str = Query(None, description="按客户端IP筛选"),
action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线、1=上线)")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 构建筛选条件(参数化查询、避免注入)
where_clause = []
params = []
if client_ip:
where_clause.append("client_ip = %s")
params.append(client_ip)
if action is not None:
where_clause.append("action = %s")
params.append(action)
# 查询总记录数(用于返回 total
count_sql = "SELECT COUNT(*) AS total FROM device_action"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 分页查询记录(按创建时间倒序、确保最新记录在前)
offset = (page - 1) * page_size
list_sql = "SELECT * FROM device_action"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_sql, params)
action_list = cursor.fetchall()
# 仅返回 total + device_actions
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
@encrypt_response()
async def get_device_actions_by_ip(
client_ip: str = Path(..., description="客户端IP地址")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 查询总记录数
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
cursor.execute(count_sql, (client_ip,))
total = cursor.fetchone()["total"]
# 2. 查询该IP的所有记录按创建时间倒序
list_sql = """
SELECT * FROM device_action
WHERE client_ip = %s
ORDER BY created_at DESC
"""
cursor.execute(list_sql, (client_ip,))
action_list = cursor.fetchall()
# 3. 返回结果
return APIResponse(
code=200,
message="查询成功",
data=DeviceActionListResponse(
total=total,
device_actions=[DeviceActionResponse(**item) for item in action_list]
)
)
except MySQLError as e:
raise Exception(f"查询记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

View File

@ -0,0 +1,164 @@
from datetime import date
from fastapi import APIRouter, Query, HTTPException, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.device_danger_schema import (
DeviceDangerResponse, DeviceDangerListResponse
)
from schema.response_schema import APIResponse
router = APIRouter(
prefix="/api/devices/dangers",
tags=["设备管理-危险记录"]
)
# 获取危险记录列表
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
@encrypt_response()
async def get_danger_list(
page: int = Query(1, ge=1, description="页码、默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
danger_type: str = Query(None, max_length=255, alias="type", description="按危险类型筛选"),
start_date: date = Query(None, description="按创建时间筛选开始日期、格式YYYY-MM-DD"),
end_date: date = Query(None, description="按创建时间筛选结束日期、格式YYYY-MM-DD")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 构建筛选条件
where_clause = []
params = []
if client_ip:
where_clause.append("client_ip = %s")
params.append(client_ip)
if danger_type:
where_clause.append("type = %s")
params.append(danger_type)
if start_date:
where_clause.append("DATE(created_at) >= %s")
params.append(start_date.strftime("%Y-%m-%d"))
if end_date:
where_clause.append("DATE(created_at) <= %s")
params.append(end_date.strftime("%Y-%m-%d"))
# 1. 统计符合条件的总记录数
count_query = "SELECT COUNT(*) AS total FROM device_danger"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 2. 分页查询记录(按创建时间倒序、最新的在前)
offset = (page - 1) * page_size
list_query = "SELECT * FROM device_danger"
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset]) # 追加分页参数
cursor.execute(list_query, params)
danger_list = cursor.fetchall()
# 转换为响应模型
return APIResponse(
code=200,
message="获取危险记录列表成功",
data=DeviceDangerListResponse(
total=total,
dangers=[DeviceDangerResponse(**item) for item in danger_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 获取单个设备的所有危险记录
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
# @encrypt_response()
async def get_device_dangers(
client_ip: str = Path(..., max_length=100, description="设备IP地址"),
page: int = Query(1, ge=1, description="页码、默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间")
):
# 先检查设备是否存在
from service.device_danger_service import check_device_exist
if not check_device_exist(client_ip):
raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 统计该设备的危险记录总数
count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
cursor.execute(count_query, (client_ip,))
total = cursor.fetchone()["total"]
# 2. 分页查询该设备的危险记录
offset = (page - 1) * page_size
list_query = """
SELECT * FROM device_danger
WHERE client_ip = %s
ORDER BY created_at DESC
LIMIT %s OFFSET %s
"""
cursor.execute(list_query, (client_ip, page_size, offset))
danger_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
data=DeviceDangerListResponse(
total=total,
dangers=[DeviceDangerResponse(**item) for item in danger_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 根据ID获取单个危险记录详情
# ------------------------------
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
@encrypt_response()
async def get_danger_detail(
danger_id: int = Path(..., ge=1, description="危险记录ID")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 查询单个危险记录
query = "SELECT * FROM device_danger WHERE id = %s"
cursor.execute(query, (danger_id,))
danger = cursor.fetchone()
if not danger:
raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
return APIResponse(
code=200,
message="获取危险记录详情成功",
data=DeviceDangerResponse(**danger)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

329
router/device_router.py Normal file
View File

@ -0,0 +1,329 @@
import asyncio
import json
from datetime import date
from fastapi import APIRouter, Query, HTTPException, Request, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.device_schema import (
DeviceCreateRequest, DeviceResponse, DeviceListResponse,
DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse
)
from schema.response_schema import APIResponse
from service.device_service import update_online_status_by_ip
from ws.ws import get_current_time_str, aes_encrypt, is_client_connected
router = APIRouter(
prefix="/api/devices",
tags=["设备管理"]
)
# 创建设备信息接口
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
@encrypt_response()
async def create_device(device_data: DeviceCreateRequest, request: Request):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否已存在
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
existing_device = cursor.fetchone()
if existing_device:
# 更新设备为在线状态
from service.device_service import update_online_status_by_ip
update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
return APIResponse(
code=200,
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
data=DeviceResponse(**existing_device)
)
# 通过 User-Agent 判断设备类型
user_agent = request.headers.get("User-Agent", "").lower()
device_type = "unknown"
if user_agent == "default":
device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown"
elif "windows" in user_agent:
device_type = "windows"
elif "android" in user_agent:
device_type = "android"
elif "linux" in user_agent:
device_type = "linux"
device_params_json = json.dumps(device_data.params) if device_data.params else None
# 插入新设备
insert_query = """
INSERT INTO devices
(client_ip, hostname, device_online_status, device_type, alarm_count, params, is_need_handler)
VALUES (%s, %s, %s, %s, %s, %s, %s)
"""
cursor.execute(insert_query, (
device_data.ip,
device_data.hostname,
0,
device_type,
0,
device_params_json,
0
))
conn.commit()
# 获取新设备并返回
device_id = cursor.lastrowid
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
new_device = cursor.fetchone()
return APIResponse(
code=200,
message="设备创建成功",
data=DeviceResponse(**new_device)
)
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"创建设备失败: {str(e)}") from e
except json.JSONDecodeError as e:
raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e
except Exception as e:
if conn:
conn.rollback()
raise e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 获取设备列表接口
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
@encrypt_response()
async def get_device_list(
page: int = Query(1, ge=1, description="页码、默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
device_type: str = Query(None, description="按设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
where_clause = []
params = []
if device_type:
where_clause.append("device_type = %s")
params.append(device_type)
if online_status is not None:
where_clause.append("device_online_status = %s")
params.append(online_status)
# 统计总数
count_query = "SELECT COUNT(*) AS total FROM devices"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 分页查询列表
offset = (page - 1) * page_size
list_query = "SELECT * FROM devices"
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_query, params)
device_list = cursor.fetchall()
return APIResponse(
code=200,
message="获取设备列表成功",
data=DeviceListResponse(
total=total,
devices=[DeviceResponse(**device) for device in device_list]
)
)
except MySQLError as e:
raise Exception(f"获取设备列表失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 获取设备上下线记录接口
# ------------------------------
@router.get("/status-history", response_model=APIResponse, summary="获取设备上下线记录")
@encrypt_response()
async def get_device_status_history(
client_ip: str = Query(None, description="客户端IP地址非必填为空时返回所有设备记录"),
page: int = Query(1, ge=1, description="页码、默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
start_date: date = Query(None, description="开始日期、格式YYYY-MM-DD"),
end_date: date = Query(None, description="结束日期、格式YYYY-MM-DD")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 检查设备是否存在仅传IP时强制指定Collation
if client_ip is not None:
# 关键调整1WHERE条件中给d.client_ip指定Collation与da一致或反之
check_query = """
SELECT id FROM devices
WHERE client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci
"""
cursor.execute(check_query, (client_ip,))
device = cursor.fetchone()
if not device:
raise HTTPException(status_code=404, detail=f"客户端IP为 {client_ip} 的设备不存在")
# 2. 构建WHERE条件
where_clause = []
params = []
# 关键调整2传IP时强制指定da.client_ip的Collation
if client_ip is not None:
where_clause.append("da.client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci")
params.append(client_ip)
if start_date:
where_clause.append("DATE(da.created_at) >= %s")
params.append(start_date.strftime("%Y-%m-%d"))
if end_date:
where_clause.append("DATE(da.created_at) <= %s")
params.append(end_date.strftime("%Y-%m-%d"))
# 3. 统计总数JOIN时强制统一Collation
count_query = """
SELECT COUNT(*) AS total
FROM device_action da
LEFT JOIN devices d
ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci
"""
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 4. 分页查询JOIN时强制统一Collation
offset = (page - 1) * page_size
list_query = """
SELECT da.*, d.id AS device_id
FROM device_action da
LEFT JOIN devices d
ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci
"""
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY da.created_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_query, params)
history_list = cursor.fetchall()
# 后续格式化响应逻辑不变...
formatted_history = []
for item in history_list:
formatted_item = {
"id": item["id"],
"device_id": item["device_id"], # 可能为NoneIP无对应设备
"client_ip": item["client_ip"],
"status": item["action"],
"status_time": item["created_at"]
}
formatted_history.append(formatted_item)
return APIResponse(
code=200,
message="获取设备上下线记录成功",
data=DeviceStatusHistoryListResponse(
total=total,
history=[DeviceStatusHistoryResponse(**item) for item in formatted_history]
)
)
except MySQLError as e:
raise Exception(f"获取设备上下线记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 通过客户端IP设置设备is_need_handler为0接口
# ------------------------------
@router.post("/need-handler/reset", response_model=APIResponse, summary="解封客户端")
@encrypt_response()
async def reset_device_need_handler(
client_ip: str = Query(..., description="目标设备的客户端IP地址必填")
):
try:
from service.device_service import update_is_need_handler_by_client_ip
success = update_is_need_handler_by_client_ip(
client_ip=client_ip,
is_need_handler=0 # 固定设置为0不需要处理
)
if success:
online_status = is_client_connected(client_ip)
# 如果设备在线,则发送消息给前端
if online_status:
# 调用 ws 发送一个消息给前端、告诉他已解锁
unlock_msg = {
"type": "unlock",
"timestamp": get_current_time_str(),
"client_ip": client_ip
}
from ws.ws import send_message_to_client
await send_message_to_client(client_ip, json.dumps(unlock_msg))
# 休眠 100 ms
await asyncio.sleep(0.1)
frame_permit_msg = {
"type": "frame",
"timestamp": get_current_time_str(),
"client_ip": client_ip
}
await send_message_to_client(client_ip, json.dumps(frame_permit_msg))
# 更新设备在线状态为1
update_online_status_by_ip(client_ip, 1)
return APIResponse(
code=200,
message=f"设备已解封",
data={
"client_ip": client_ip,
"is_need_handler": 0,
"status_desc": "设备已解封"
}
)
# 捕获工具方法抛出的业务异常如IP为空、设备不存在
except ValueError as e:
# 业务异常返回400/404状态码与现有接口异常规范一致
raise HTTPException(
status_code=404 if "设备不存在" in str(e) else 400,
detail=str(e)
) from e
# 捕获数据库层面异常如连接失败、SQL执行错误
except MySQLError as e:
raise Exception(f"设置is_need_handler失败数据库操作异常 - {str(e)}") from e
# 捕获其他未知异常
except Exception as e:
raise Exception(f"设置is_need_handler失败未知错误 - {str(e)}") from e

326
router/face_router.py Normal file
View File

@ -0,0 +1,326 @@
from io import BytesIO
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.face_schema import (
FaceCreateRequest,
FaceResponse,
FaceListResponse
)
from schema.response_schema import APIResponse
from service.face_service import update_face_data
from util.face_util import add_binary_data
from service.file_service import save_source_file
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
# 创建人脸记录
@router.post("", response_model=APIResponse, summary="创建人脸记录")
@encrypt_response()
async def create_face(
request: Request,
name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传)")
):
conn = None
cursor = None
try:
face_create = FaceCreateRequest(name=name)
client_ip = request.client.host if request.client else ""
if not client_ip:
raise HTTPException(status_code=400, detail="无法获取客户端IP")
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 先读取文件内容
file_content = await file.read()
# 将文件指针重置到开头
await file.seek(0)
# 再保存文件
path = save_source_file(file, "face")
# 提取人脸特征
detect_success, detect_result = add_binary_data(file_content)
if not detect_success:
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
eigenvalue = detect_result
# 插入数据库
insert_query = """
INSERT INTO face (name, eigenvalue, address)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (face_create.name, str(eigenvalue), path))
conn.commit()
# 查询新记录
cursor.execute("""
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = LAST_INSERT_ID()
""")
created_face = cursor.fetchone()
if not created_face:
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
# TODO 重新加载人脸模型
update_face_data()
return APIResponse(
code=200,
message=f"人脸记录创建成功ID: {created_face['id']}",
data=FaceResponse(**created_face)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
except Exception as e:
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
finally:
await file.close()
db.close_connection(conn, cursor)
# 获取单个人脸记录
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
@encrypt_response()
async def get_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = """
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
"""
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
return APIResponse(
code=200,
message="查询成功",
data=FaceResponse(**face)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 获取人脸列表
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
@encrypt_response()
async def get_face_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
has_eigenvalue: bool = Query(None)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if has_eigenvalue is not None:
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
# 总记录数
count_query = "SELECT COUNT(*) AS total FROM face"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 列表数据
offset = (page - 1) * page_size
list_query = """
SELECT id, name, address, created_at, updated_at
FROM face
"""
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_query, params)
face_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取成功(共{total}条)",
data=FaceListResponse(
total=total,
faces=[FaceResponse(**face) for face in face_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 删除人脸记录
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
@encrypt_response()
async def delete_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
exist_face = cursor.fetchone()
if not exist_face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
old_db_path = exist_face["address"]
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
conn.commit()
# 删除图片
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink()
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
extra_msg = "(已同步删除图片)"
except Exception as e:
print(f"[FaceRouter] 删除图片失败:{str(e)}")
extra_msg = "(图片删除失败)"
else:
extra_msg = "(图片不存在)"
else:
extra_msg = "(无关联图片)"
# TODO 重新加载人脸模型
update_face_data()
return APIResponse(
code=200,
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片")
# @encrypt_response()
async def batch_import_faces(
folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**")
):
conn = None
cursor = None
success_count = 0 # 成功导入数量
fail_list = [] # 失败记录(包含文件名、错误原因)
try:
# 1. 验证文件夹有效性
folder = Path(folder_path)
if not folder.exists() or not folder.is_dir():
raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录")
# 2. 定义支持的图片格式
supported_extensions = {".png", ".jpg", ".jpeg", ".webp"}
# 3. 数据库连接初始化
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 4. 遍历文件夹内所有文件
for file_path in folder.iterdir():
if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
file_name = file_path.stem # 提取文件名(不含后缀)作为 `name`
try:
# 4.1 读取文件二进制内容
with open(file_path, "rb") as f:
file_content = f.read()
# 4.2 构造模拟的 UploadFile 对象(用于兼容 `save_source_file`
mock_file = UploadFile(
filename=file_path.name,
file=BytesIO(file_content)
)
# 4.3 保存文件到指定目录
saved_path = save_source_file(mock_file, "face")
# 4.4 提取人脸特征
detect_success, detect_result = add_binary_data(file_content)
if not detect_success:
fail_list.append({
"name": file_name,
"file_path": str(file_path),
"error": f"人脸检测失败:{detect_result}"
})
continue # 跳过当前文件,处理下一个
eigenvalue = detect_result
# 4.5 插入数据库
insert_sql = """
INSERT INTO face (name, eigenvalue, address)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path))
conn.commit() # 提交当前文件的插入操作
success_count += 1
except Exception as e:
# 捕获单文件处理的异常,记录后继续处理其他文件
fail_list.append({
"name": file_name,
"file_path": str(file_path),
"error": str(e)
})
if conn:
conn.rollback() # 回滚当前失败文件的插入
# 5. 重新加载人脸模型(确保新增数据生效)
update_face_data()
# 6. 构造返回结果
return APIResponse(
code=200,
message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)}",
data={
"success_count": success_count,
"fail_details": fail_list
}
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e
except HTTPException:
raise # 直接抛出400等由业务逻辑触发的HTTP异常
except Exception as e:
raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e
finally:
db.close_connection(conn, cursor)

31
router/file_router.py Normal file
View File

@ -0,0 +1,31 @@
import os
from fastapi import FastAPI, HTTPException, Path, APIRouter
from fastapi.responses import FileResponse
from service.file_service import UPLOAD_ROOT
router = APIRouter(
prefix="/api/file",
tags=["文件管理"]
)
@router.get("/download/{relative_path:path}", summary="下载文件")
async def download_file(
relative_path: str = Path(..., description="文件的相对路径")
):
file_path = os.path.abspath(os.path.join(UPLOAD_ROOT, relative_path))
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
if not os.path.isfile(file_path):
raise HTTPException(status_code=400, detail="路径指向的不是文件")
if not file_path.startswith(os.path.abspath(UPLOAD_ROOT)):
raise HTTPException(status_code=403, detail="无权访问该文件")
return FileResponse(
path=file_path,
filename=os.path.basename(file_path),
media_type="application/octet-stream"
)

269
router/model_router.py Normal file
View File

@ -0,0 +1,269 @@
import os
from pathlib import Path
from service.file_service import save_source_file
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.model_schema import (
ModelResponse,
ModelListResponse
)
from schema.response_schema import APIResponse
from service.model_service import ALLOWED_MODEL_EXT, MAX_MODEL_SIZE, load_yolo_model
router = APIRouter(prefix="/api/models", tags=["模型管理"])
# 上传模型
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
@encrypt_response()
async def upload_model(
name: str = Form(..., description="模型名称"),
description: str = Form(None, description="模型描述"),
file: UploadFile = File(..., description=f"YOLO模型文件.pt、最大{MAX_MODEL_SIZE // 1024 // 1024}MB")
):
conn = None
cursor = None
try:
# 校验文件格式
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
if file_ext not in ALLOWED_MODEL_EXT:
raise HTTPException(
status_code=400,
detail=f"仅支持{ALLOWED_MODEL_EXT}格式、当前:{file_ext}"
)
# 校验文件大小
if file.size > MAX_MODEL_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB、当前{file.size // 1024 // 1024}MB"
)
# 保存文件
file_path = save_source_file(file, "model")
# 数据库操作
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
insert_sql = """
INSERT INTO model (name, path, is_default, description, file_size)
VALUES (%s, %s, 0, %s, %s)
"""
cursor.execute(insert_sql, (name, file_path, description, file.size))
conn.commit()
# 获取新增记录
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
new_model = cursor.fetchone()
if not new_model:
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
return APIResponse(
code=200,
message=f"模型上传成功",
data=ModelResponse(**new_model)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
except Exception as e:
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
finally:
await file.close()
db.close_connection(conn, cursor)
# 获取模型列表
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
@encrypt_response()
async def get_model_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
is_default: bool = Query(None)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if is_default is not None:
where_clause.append("is_default = %s")
params.append(1 if is_default else 0)
# 总记录数
count_sql = "SELECT COUNT(*) AS total FROM model"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 分页数据
offset = (page - 1) * page_size
list_sql = "SELECT * FROM model"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_sql, params)
model_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取成功!",
data=ModelListResponse(
total=total,
models=[ModelResponse(**model) for model in model_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 更换默认模型
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型")
@encrypt_response()
async def set_default_model(
model_id: int
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
conn.autocommit = False
# 校验目标模型是否存在
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
target_model = cursor.fetchone()
if not target_model:
raise HTTPException(status_code=404, detail=f"目标模型不存在!")
# 检查是否已为默认模型
if target_model["is_default"]:
return APIResponse(
code=200,
message=f"已是默认模型、无需更换",
data=ModelResponse(**target_model)
)
# 数据库事务:更新默认模型状态
try:
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
cursor.execute(
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
(model_id,)
)
conn.commit()
except MySQLError as e:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
) from e
# 更新模型
load_yolo_model()
# 返回成功响应
return APIResponse(
code=200,
message=f"更换成功",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
if conn:
conn.autocommit = True
db.close_connection(conn, cursor)
# 路由文件(如 model_router.py中的删除接口
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
@encrypt_response()
async def delete_model(model_id: int):
# 1. 正确导入 model_service 中的全局变量(关键修复:变量名匹配)
from service.model_service import (
current_yolo_model,
current_model_absolute_path,
load_yolo_model # 用于删除后重新加载模型(可选)
)
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 2. 查询待删除模型信息
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
exist_model = cursor.fetchone()
if not exist_model:
raise HTTPException(status_code=404, detail=f"模型不存在!")
# 3. 关键判断:①默认模型不可删 ②正在使用的模型不可删
if exist_model["is_default"]:
raise HTTPException(status_code=400, detail="默认模型不可删除!")
# 计算待删除模型的绝对路径(与 model_service 逻辑一致)
from service.file_service import get_absolute_path
del_model_abs_path = get_absolute_path(exist_model["path"])
# 判断是否正在使用(对比 current_model_absolute_path
if current_model_absolute_path and del_model_abs_path == current_model_absolute_path:
raise HTTPException(status_code=400, detail="该模型正在使用中,禁止删除!")
# 4. 先删除数据库记录(避免文件删除失败导致数据不一致)
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
conn.commit()
# 5. 再删除本地文件(捕获文件删除异常,不影响数据库删除结果)
extra_msg = ""
try:
if os.path.exists(del_model_abs_path):
os.remove(del_model_abs_path) # 或用 Path(del_model_abs_path).unlink()
extra_msg = "(本地文件已同步删除)"
else:
extra_msg = "(本地文件不存在,无需删除)"
except Exception as e:
extra_msg = f"(本地文件删除失败:{str(e)}"
# 6. 若删除后当前模型为空(极端情况),重新加载默认模型(可选优化)
if current_yolo_model is None:
try:
load_yolo_model()
print(f"[模型删除后] 重新加载默认模型成功")
except Exception as e:
print(f"[模型删除后] 重新加载默认模型失败:{str(e)}")
return APIResponse(
code=200,
message=f"模型删除成功!",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)

306
router/sensitive_router.py Normal file
View File

@ -0,0 +1,306 @@
from fastapi import APIRouter, Depends, HTTPException, Query, File, UploadFile
from mysql.connector import Error as MySQLError
from typing import Optional
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.sensitive_schema import (
SensitiveCreateRequest,
SensitiveResponse,
SensitiveListResponse
)
from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse
from service.ocr_service import set_forbidden_words
from service.sensitive_service import get_all_sensitive_words
router = APIRouter(
prefix="/api/sensitives",
tags=["敏感信息管理"]
)
# 创建敏感信息记录
@router.post("", response_model=APIResponse, summary="创建敏感信息记录")
@encrypt_response()
async def create_sensitive(
sensitive: SensitiveCreateRequest,
current_user: UserResponse = Depends(get_current_user)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入新敏感信息记录到数据库不包含ID、由数据库自动生成
insert_query = """
INSERT INTO sensitives (name, created_at, updated_at)
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (sensitive.name,))
conn.commit()
# 获取刚插入记录的ID使用LAST_INSERT_ID()函数)
new_id = cursor.lastrowid
# 查询刚创建的记录并返回
select_query = "SELECT * FROM sensitives WHERE id = %s"
cursor.execute(select_query, (new_id,))
created_sensitive = cursor.fetchone()
# 重新加载最新的敏感词
set_forbidden_words(get_all_sensitive_words())
return APIResponse(
code=200,
message="敏感信息记录创建成功",
data=SensitiveResponse(**created_sensitive)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"创建敏感信息记录失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# 获取敏感信息分页列表
@router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
@encrypt_response()
async def get_sensitive_list(
page: int = Query(1, ge=1, description="页码默认1、最小1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数默认10、1-100"),
name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 构建查询条件(支持关键词搜索)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%") # 模糊匹配关键词
# 2. 查询总记录数(用于分页计算)
count_sql = "SELECT COUNT(*) AS total FROM sensitives"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params.copy()) # 复制参数列表、避免后续污染
total = cursor.fetchone()["total"]
# 3. 计算分页偏移量
offset = (page - 1) * page_size
# 4. 分页查询敏感词数据(按更新时间倒序、最新的在前)
list_sql = "SELECT * FROM sensitives"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
# 排序+分页LIMIT 条数 OFFSET 偏移量)
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
# 补充分页参数page_size和offset
params.extend([page_size, offset])
cursor.execute(list_sql, params)
sensitive_list = cursor.fetchall()
# 5. 构造分页响应数据
return APIResponse(
code=200,
message=f"敏感信息列表查询成功(共{total}条记录、当前第{page}页)",
data=SensitiveListResponse(
total=total,
sensitives=[SensitiveResponse(**item) for item in sensitive_list]
)
)
except MySQLError as e:
raise HTTPException(
status_code=500,
detail=f"查询敏感信息列表失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# 删除敏感信息记录
@router.delete("/{sensitive_id}", response_model=APIResponse, summary="删除敏感信息记录")
@encrypt_response()
async def delete_sensitive(
sensitive_id: int,
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
"""
删除敏感信息记录:
- 需登录认证
- 根据ID删除敏感信息记录
- 返回删除成功信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 检查记录是否存在
check_query = "SELECT id FROM sensitives WHERE id = %s"
cursor.execute(check_query, (sensitive_id,))
existing_sensitive = cursor.fetchone()
if not existing_sensitive:
raise HTTPException(
status_code=404,
detail=f"ID为 {sensitive_id} 的敏感信息记录不存在"
)
# 2. 执行删除操作
delete_query = "DELETE FROM sensitives WHERE id = %s"
cursor.execute(delete_query, (sensitive_id,))
conn.commit()
# 重新加载最新的敏感词
set_forbidden_words(get_all_sensitive_words())
return APIResponse(
code=200,
message=f"ID为 {sensitive_id} 的敏感信息记录删除成功",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"删除敏感信息记录失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# 批量导入敏感信息从txt文件
@router.post("/batch-import", response_model=APIResponse, summary="批量导入敏感信息从txt文件")
@encrypt_response()
async def batch_import_sensitives(
file: UploadFile = File(..., description="包含敏感词的txt文件每行一个敏感词"),
# current_user: UserResponse = Depends(get_current_user) # 添加认证依赖
):
"""
批量导入敏感信息:
- 需登录认证
- 接收txt文件文件中每行一个敏感词
- 批量插入到数据库中(仅插入不存在的敏感词)
- 返回导入结果统计
"""
# 检查文件类型
filename = file.filename or ""
if not filename.lower().endswith(".txt"):
raise HTTPException(
status_code=400,
detail=f"请上传txt格式的文件当前文件格式: {filename.split('.')[-1] if '.' in filename else '未知'}"
)
# 检查文件大小
file_size = await file.read(1) # 读取1字节获取文件信息
await file.seek(0) # 重置文件指针
if not file_size: # 文件为空
raise HTTPException(
status_code=400,
detail="上传的文件为空,请提供有效的敏感词文件"
)
conn = None
cursor = None
try:
# 读取文件内容
contents = await file.read()
# 按行分割内容,处理不同操作系统的换行符
lines = contents.decode("utf-8", errors="replace").splitlines()
# 过滤空行和仅含空白字符的行
sensitive_words = [line.strip() for line in lines if line.strip()]
if not sensitive_words:
return APIResponse(
code=200,
message="文件中没有有效的敏感词",
data={"imported": 0, "total": 0}
)
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 先查询数据库中已存在的敏感词
query = "SELECT name FROM sensitives WHERE name IN (%s)"
# 处理参数,根据敏感词数量生成占位符
placeholders = ', '.join(['%s'] * len(sensitive_words))
cursor.execute(query % placeholders, sensitive_words)
existing_words = {row['name'] for row in cursor.fetchall()}
# 过滤掉已存在的敏感词
new_words = [word for word in sensitive_words if word not in existing_words]
if not new_words:
return APIResponse(
code=200,
message="所有敏感词均已存在于数据库中",
data={
"total": len(sensitive_words),
"imported": 0,
"duplicates": len(sensitive_words),
"message": f"共处理{len(sensitive_words)}个敏感词,全部已存在,未导入任何新敏感词"
}
)
# 批量插入新的敏感词
insert_query = """
INSERT INTO sensitives (name, created_at, updated_at)
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
# 准备参数列表
params = [(word,) for word in new_words]
# 执行批量插入
cursor.executemany(insert_query, params)
conn.commit()
# 重新加载最新的敏感词
set_forbidden_words(get_all_sensitive_words())
return APIResponse(
code=200,
message=f"敏感词批量导入成功",
data={
"total": len(sensitive_words),
"imported": len(new_words),
"duplicates": len(sensitive_words) - len(new_words),
"message": f"共处理{len(sensitive_words)}个敏感词,成功导入{len(new_words)}个,{len(sensitive_words) - len(new_words)}个已存在"
}
)
except UnicodeDecodeError:
raise HTTPException(
status_code=400,
detail="文件编码格式错误请使用UTF-8编码的txt文件"
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"批量导入敏感词失败: {str(e)}"
) from e
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"处理文件时发生错误: {str(e)}"
) from e
finally:
await file.close()
db.close_connection(conn, cursor)

246
router/user_router.py Normal file
View File

@ -0,0 +1,246 @@
from datetime import timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from middle.auth_middleware import (
get_password_hash,
verify_password,
create_access_token,
ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_user
)
from schema.response_schema import APIResponse
from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
router = APIRouter(
prefix="/api/users",
tags=["用户管理"]
)
# 用户注册接口
@router.post("/register", response_model=APIResponse, summary="用户注册")
@encrypt_response()
async def user_register(request: UserRegisterRequest):
"""
用户注册:
- 校验用户名是否已存在
- 加密密码后插入数据库
- 返回注册成功信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 检查用户名是否已存在(唯一索引)
check_query = "SELECT username FROM users WHERE username = %s"
cursor.execute(check_query, (request.username,))
existing_user = cursor.fetchone()
if existing_user:
raise HTTPException(
status_code=400,
detail=f"用户名 '{request.username}' 已存在、请更换其他用户名"
)
# 2. 加密密码
hashed_password = get_password_hash(request.password)
# 3. 插入新用户到数据库
insert_query = """
INSERT INTO users (username, password)
VALUES (%s, %s)
"""
cursor.execute(insert_query, (request.username, hashed_password))
conn.commit() # 提交事务
# 4. 返回注册成功响应
return APIResponse(
code=200, # 200 表示资源创建成功
message=f"用户 '{request.username}' 注册成功",
data=None
)
except MySQLError as e:
conn.rollback() # 数据库错误时回滚事务
raise Exception(f"注册失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 用户登录接口
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token")
@encrypt_response()
async def user_login(request: UserLoginRequest):
"""
用户登录:
- 校验用户名是否存在
- 校验密码是否正确
- 生成 JWT Token 并返回
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 修复: SQL查询添加 created_at 和 updated_at 字段
query = """
SELECT id, username, password, created_at, updated_at
FROM users
WHERE username = %s
"""
cursor.execute(query, (request.username,))
user = cursor.fetchone()
# 2. 校验用户名和密码
if not user or not verify_password(request.password, user["password"]):
raise HTTPException(
status_code=401,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 3. 生成 Token过期时间从配置读取
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user["username"]},
expires_delta=access_token_expires
)
# 4. 返回 Token 和用户基本信息
return APIResponse(
code=200,
message="登录成功",
data={
"access_token": access_token,
"token_type": "bearer",
"user": UserResponse(
id=user["id"],
username=user["username"],
created_at=user.get("created_at"),
updated_at=user.get("updated_at")
)
}
)
except MySQLError as e:
raise Exception(f"登录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 获取当前登录用户信息(需认证)
@router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
@encrypt_response()
async def get_current_user_info(
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
):
"""
获取当前登录用户信息:
- 需在请求头携带 Token格式: Bearer <token>
- 认证通过后返回用户信息
"""
return APIResponse(
code=200,
message="获取用户信息成功",
data=current_user
)
# 获取用户列表(仅需登录权限)
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
@encrypt_response()
async def get_user_list(
page: int = Query(1, ge=1, description="页码、从1开始"),
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
username: Optional[str] = Query(None, description="用户名模糊搜索"),
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
):
"""
获取用户列表:
- 需登录权限(请求头携带 Token: Bearer <token>
- 支持分页查询page=页码、page_size=每页条数)
- 支持用户名模糊搜索(如输入"test"可匹配"test123""admin_test"等)
- 仅返回用户ID、用户名、创建时间、更新时间不包含密码等敏感信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 计算分页偏移量page从1开始、偏移量=(页码-1*每页条数)
offset = (page - 1) * page_size
# 基础查询(仅查非敏感字段)
base_query = """
SELECT id, username, created_at, updated_at
FROM users
"""
# 总条数查询(用于分页计算)
count_query = "SELECT COUNT(*) as total FROM users"
# 条件拼接(支持用户名模糊搜索)
conditions = []
params = []
if username:
conditions.append("username LIKE %s")
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
# 构建最终查询语句
if conditions:
where_clause = " WHERE " + " AND ".join(conditions)
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
final_count_query = f"{count_query}{where_clause}"
params.extend([page_size, offset]) # 追加分页参数
else:
final_query = f"{base_query} LIMIT %s OFFSET %s"
final_count_query = count_query
params = [page_size, offset]
# 1. 查询用户列表数据
cursor.execute(final_query, params)
users = cursor.fetchall()
# 2. 查询总条数(用于计算总页数)
count_params = [f"%{username}%"] if username else []
cursor.execute(final_count_query, count_params)
total = cursor.fetchone()["total"]
# 3. 转换为UserResponse模型确保字段匹配
user_list = [
UserResponse(
id=user["id"],
username=user["username"],
created_at=user["created_at"],
updated_at=user["updated_at"]
)
for user in users
]
# 4. 计算总页数向上取整、如11条数据每页10条=2页
total_pages = (total + page_size - 1) // page_size
# 返回结果(包含列表和分页信息)
return APIResponse(
code=200,
message="获取用户列表成功",
data={
"users": user_list,
"pagination": {
"page": page, # 当前页码
"page_size": page_size, # 每页条数
"total": total, # 总数据量
"total_pages": total_pages # 总页数
}
}
)
except MySQLError as e:
raise Exception(f"获取用户列表失败: {str(e)}") from e
finally:
# 无论成功失败、都关闭数据库连接
db.close_connection(conn, cursor)