内容安全审核
This commit is contained in:
BIN
router/__pycache__/device_danger_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/device_danger_router.cpython-310.pyc
Normal file
Binary file not shown.
BIN
router/__pycache__/device_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/device_router.cpython-310.pyc
Normal file
Binary file not shown.
BIN
router/__pycache__/face_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/face_router.cpython-310.pyc
Normal file
Binary file not shown.
BIN
router/__pycache__/file_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/file_router.cpython-310.pyc
Normal file
Binary file not shown.
BIN
router/__pycache__/model_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/model_router.cpython-310.pyc
Normal file
Binary file not shown.
BIN
router/__pycache__/sensitive_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/sensitive_router.cpython-310.pyc
Normal file
Binary file not shown.
BIN
router/__pycache__/user_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/user_router.cpython-310.pyc
Normal file
Binary file not shown.
112
router/device_action_router.py
Normal file
112
router/device_action_router.py
Normal file
@ -0,0 +1,112 @@
|
||||
from fastapi import APIRouter, Query, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_action_schema import (
|
||||
DeviceActionResponse,
|
||||
DeviceActionListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
# 路由配置
|
||||
router = APIRouter(
|
||||
prefix="/api/device/actions",
|
||||
tags=["设备操作记录"]
|
||||
)
|
||||
|
||||
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
|
||||
@encrypt_response()
|
||||
async def get_device_action_list(
|
||||
page: int = Query(1, ge=1, description="页码、默认1"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
|
||||
client_ip: str = Query(None, description="按客户端IP筛选"),
|
||||
action: int = Query(None, ge=0, le=1, description="按状态筛选(0=离线、1=上线)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
# 构建筛选条件(参数化查询、避免注入)
|
||||
where_clause = []
|
||||
params = []
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if action is not None:
|
||||
where_clause.append("action = %s")
|
||||
params.append(action)
|
||||
|
||||
# 查询总记录数(用于返回 total)
|
||||
count_sql = "SELECT COUNT(*) AS total FROM device_action"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页查询记录(按创建时间倒序、确保最新记录在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM device_action"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
action_list = cursor.fetchall()
|
||||
|
||||
# 仅返回 total + device_actions
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=DeviceActionListResponse(
|
||||
total=total,
|
||||
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
|
||||
@encrypt_response()
|
||||
async def get_device_actions_by_ip(
|
||||
client_ip: str = Path(..., description="客户端IP地址")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 查询总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
|
||||
cursor.execute(count_sql, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 查询该IP的所有记录(按创建时间倒序)
|
||||
list_sql = """
|
||||
SELECT * FROM device_action
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
cursor.execute(list_sql, (client_ip,))
|
||||
action_list = cursor.fetchall()
|
||||
|
||||
# 3. 返回结果
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=DeviceActionListResponse(
|
||||
total=total,
|
||||
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
164
router/device_danger_router.py
Normal file
164
router/device_danger_router.py
Normal file
@ -0,0 +1,164 @@
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Query, HTTPException, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_danger_schema import (
|
||||
DeviceDangerResponse, DeviceDangerListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/devices/dangers",
|
||||
tags=["设备管理-危险记录"]
|
||||
)
|
||||
|
||||
# 获取危险记录列表
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
|
||||
@encrypt_response()
|
||||
async def get_danger_list(
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
|
||||
danger_type: str = Query(None, max_length=255, alias="type", description="按危险类型筛选"),
|
||||
start_date: date = Query(None, description="按创建时间筛选(开始日期、格式YYYY-MM-DD)"),
|
||||
end_date: date = Query(None, description="按创建时间筛选(结束日期、格式YYYY-MM-DD)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建筛选条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if danger_type:
|
||||
where_clause.append("type = %s")
|
||||
params.append(danger_type)
|
||||
if start_date:
|
||||
where_clause.append("DATE(created_at) >= %s")
|
||||
params.append(start_date.strftime("%Y-%m-%d"))
|
||||
if end_date:
|
||||
where_clause.append("DATE(created_at) <= %s")
|
||||
params.append(end_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# 1. 统计符合条件的总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询记录(按创建时间倒序、最新的在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_query = "SELECT * FROM device_danger"
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
# 转换为响应模型
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录列表成功",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取单个设备的所有危险记录
|
||||
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
|
||||
# @encrypt_response()
|
||||
async def get_device_dangers(
|
||||
client_ip: str = Path(..., max_length=100, description="设备IP地址"),
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间")
|
||||
):
|
||||
# 先检查设备是否存在
|
||||
from service.device_danger_service import check_device_exist
|
||||
if not check_device_exist(client_ip):
|
||||
raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 统计该设备的危险记录总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
|
||||
cursor.execute(count_query, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询该设备的危险记录
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT * FROM device_danger
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
"""
|
||||
cursor.execute(list_query, (client_ip, page_size, offset))
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 根据ID获取单个危险记录详情
|
||||
# ------------------------------
|
||||
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
|
||||
@encrypt_response()
|
||||
async def get_danger_detail(
|
||||
danger_id: int = Path(..., ge=1, description="危险记录ID")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询单个危险记录
|
||||
query = "SELECT * FROM device_danger WHERE id = %s"
|
||||
cursor.execute(query, (danger_id,))
|
||||
danger = cursor.fetchone()
|
||||
|
||||
if not danger:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录详情成功",
|
||||
data=DeviceDangerResponse(**danger)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
329
router/device_router.py
Normal file
329
router/device_router.py
Normal file
@ -0,0 +1,329 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Query, HTTPException, Request, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_schema import (
|
||||
DeviceCreateRequest, DeviceResponse, DeviceListResponse,
|
||||
DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from service.device_service import update_online_status_by_ip
|
||||
from ws.ws import get_current_time_str, aes_encrypt, is_client_connected
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/devices",
|
||||
tags=["设备管理"]
|
||||
)
|
||||
|
||||
|
||||
# 创建设备信息接口
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||
@encrypt_response()
|
||||
async def create_device(device_data: DeviceCreateRequest, request: Request):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否已存在
|
||||
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
|
||||
existing_device = cursor.fetchone()
|
||||
if existing_device:
|
||||
# 更新设备为在线状态
|
||||
from service.device_service import update_online_status_by_ip
|
||||
update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
|
||||
data=DeviceResponse(**existing_device)
|
||||
)
|
||||
|
||||
# 通过 User-Agent 判断设备类型
|
||||
user_agent = request.headers.get("User-Agent", "").lower()
|
||||
device_type = "unknown"
|
||||
if user_agent == "default":
|
||||
device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown"
|
||||
elif "windows" in user_agent:
|
||||
device_type = "windows"
|
||||
elif "android" in user_agent:
|
||||
device_type = "android"
|
||||
elif "linux" in user_agent:
|
||||
device_type = "linux"
|
||||
|
||||
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
||||
|
||||
# 插入新设备
|
||||
insert_query = """
|
||||
INSERT INTO devices
|
||||
(client_ip, hostname, device_online_status, device_type, alarm_count, params, is_need_handler)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
device_data.ip,
|
||||
device_data.hostname,
|
||||
0,
|
||||
device_type,
|
||||
0,
|
||||
device_params_json,
|
||||
0
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取新设备并返回
|
||||
device_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
|
||||
new_device = cursor.fetchone()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="设备创建成功",
|
||||
data=DeviceResponse(**new_device)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建设备失败: {str(e)}") from e
|
||||
except json.JSONDecodeError as e:
|
||||
raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 获取设备列表接口
|
||||
# ------------------------------
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
||||
@encrypt_response()
|
||||
async def get_device_list(
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
device_type: str = Query(None, description="按设备类型筛选"),
|
||||
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if device_type:
|
||||
where_clause.append("device_type = %s")
|
||||
params.append(device_type)
|
||||
if online_status is not None:
|
||||
where_clause.append("device_online_status = %s")
|
||||
params.append(online_status)
|
||||
|
||||
# 统计总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM devices"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页查询列表
|
||||
offset = (page - 1) * page_size
|
||||
list_query = "SELECT * FROM devices"
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
device_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取设备列表成功",
|
||||
data=DeviceListResponse(
|
||||
total=total,
|
||||
devices=[DeviceResponse(**device) for device in device_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 获取设备上下线记录接口
|
||||
# ------------------------------
|
||||
@router.get("/status-history", response_model=APIResponse, summary="获取设备上下线记录")
|
||||
@encrypt_response()
|
||||
async def get_device_status_history(
|
||||
client_ip: str = Query(None, description="客户端IP地址(非必填,为空时返回所有设备记录)"),
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
start_date: date = Query(None, description="开始日期、格式YYYY-MM-DD"),
|
||||
end_date: date = Query(None, description="结束日期、格式YYYY-MM-DD")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 检查设备是否存在(仅传IP时):强制指定Collation
|
||||
if client_ip is not None:
|
||||
# 关键调整1:WHERE条件中给d.client_ip指定Collation(与da一致或反之)
|
||||
check_query = """
|
||||
SELECT id FROM devices
|
||||
WHERE client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci
|
||||
"""
|
||||
cursor.execute(check_query, (client_ip,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail=f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 2. 构建WHERE条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
# 关键调整2:传IP时,强制指定da.client_ip的Collation
|
||||
if client_ip is not None:
|
||||
where_clause.append("da.client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci")
|
||||
params.append(client_ip)
|
||||
if start_date:
|
||||
where_clause.append("DATE(da.created_at) >= %s")
|
||||
params.append(start_date.strftime("%Y-%m-%d"))
|
||||
if end_date:
|
||||
where_clause.append("DATE(da.created_at) <= %s")
|
||||
params.append(end_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# 3. 统计总数:JOIN时强制统一Collation
|
||||
count_query = """
|
||||
SELECT COUNT(*) AS total
|
||||
FROM device_action da
|
||||
LEFT JOIN devices d
|
||||
ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci
|
||||
"""
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 4. 分页查询:JOIN时强制统一Collation
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT da.*, d.id AS device_id
|
||||
FROM device_action da
|
||||
LEFT JOIN devices d
|
||||
ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci
|
||||
"""
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY da.created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
cursor.execute(list_query, params)
|
||||
history_list = cursor.fetchall()
|
||||
|
||||
# 后续格式化响应逻辑不变...
|
||||
formatted_history = []
|
||||
for item in history_list:
|
||||
formatted_item = {
|
||||
"id": item["id"],
|
||||
"device_id": item["device_id"], # 可能为None(IP无对应设备)
|
||||
"client_ip": item["client_ip"],
|
||||
"status": item["action"],
|
||||
"status_time": item["created_at"]
|
||||
}
|
||||
formatted_history.append(formatted_item)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取设备上下线记录成功",
|
||||
data=DeviceStatusHistoryListResponse(
|
||||
total=total,
|
||||
history=[DeviceStatusHistoryResponse(**item) for item in formatted_history]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备上下线记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# ------------------------------
|
||||
# 通过客户端IP设置设备is_need_handler为0接口
|
||||
# ------------------------------
|
||||
@router.post("/need-handler/reset", response_model=APIResponse, summary="解封客户端")
|
||||
@encrypt_response()
|
||||
async def reset_device_need_handler(
|
||||
client_ip: str = Query(..., description="目标设备的客户端IP地址(必填)")
|
||||
):
|
||||
try:
|
||||
from service.device_service import update_is_need_handler_by_client_ip
|
||||
success = update_is_need_handler_by_client_ip(
|
||||
client_ip=client_ip,
|
||||
is_need_handler=0 # 固定设置为0(不需要处理)
|
||||
)
|
||||
|
||||
if success:
|
||||
online_status = is_client_connected(client_ip)
|
||||
|
||||
# 如果设备在线,则发送消息给前端
|
||||
if online_status:
|
||||
# 调用 ws 发送一个消息给前端、告诉他已解锁
|
||||
unlock_msg = {
|
||||
"type": "unlock",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": client_ip
|
||||
}
|
||||
from ws.ws import send_message_to_client
|
||||
await send_message_to_client(client_ip, json.dumps(unlock_msg))
|
||||
|
||||
# 休眠 100 ms
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
frame_permit_msg = {
|
||||
"type": "frame",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": client_ip
|
||||
}
|
||||
await send_message_to_client(client_ip, json.dumps(frame_permit_msg))
|
||||
|
||||
# 更新设备在线状态为1
|
||||
update_online_status_by_ip(client_ip, 1)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备已解封",
|
||||
data={
|
||||
"client_ip": client_ip,
|
||||
"is_need_handler": 0,
|
||||
"status_desc": "设备已解封"
|
||||
}
|
||||
)
|
||||
|
||||
# 捕获工具方法抛出的业务异常(如IP为空、设备不存在)
|
||||
except ValueError as e:
|
||||
# 业务异常返回400/404状态码(与现有接口异常规范一致)
|
||||
raise HTTPException(
|
||||
status_code=404 if "设备不存在" in str(e) else 400,
|
||||
detail=str(e)
|
||||
) from e
|
||||
|
||||
# 捕获数据库层面异常(如连接失败、SQL执行错误)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"设置is_need_handler失败:数据库操作异常 - {str(e)}") from e
|
||||
|
||||
# 捕获其他未知异常
|
||||
except Exception as e:
|
||||
raise Exception(f"设置is_need_handler失败:未知错误 - {str(e)}") from e
|
||||
|
||||
|
||||
|
||||
326
router/face_router.py
Normal file
326
router/face_router.py
Normal file
@ -0,0 +1,326 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.face_schema import (
|
||||
FaceCreateRequest,
|
||||
FaceResponse,
|
||||
FaceListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from service.face_service import update_face_data
|
||||
from util.face_util import add_binary_data
|
||||
from service.file_service import save_source_file
|
||||
|
||||
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
|
||||
|
||||
|
||||
# 创建人脸记录
|
||||
@router.post("", response_model=APIResponse, summary="创建人脸记录")
|
||||
@encrypt_response()
|
||||
async def create_face(
|
||||
request: Request,
|
||||
name: str = Form(None, max_length=255, description="名称(可选)"),
|
||||
file: UploadFile = File(..., description="人脸文件(必传)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
face_create = FaceCreateRequest(name=name)
|
||||
client_ip = request.client.host if request.client else ""
|
||||
if not client_ip:
|
||||
raise HTTPException(status_code=400, detail="无法获取客户端IP")
|
||||
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
# 先读取文件内容
|
||||
file_content = await file.read()
|
||||
# 将文件指针重置到开头
|
||||
await file.seek(0)
|
||||
# 再保存文件
|
||||
path = save_source_file(file, "face")
|
||||
# 提取人脸特征
|
||||
detect_success, detect_result = add_binary_data(file_content)
|
||||
if not detect_success:
|
||||
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
|
||||
eigenvalue = detect_result
|
||||
|
||||
# 插入数据库
|
||||
insert_query = """
|
||||
INSERT INTO face (name, eigenvalue, address)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue), path))
|
||||
conn.commit()
|
||||
|
||||
# 查询新记录
|
||||
cursor.execute("""
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = LAST_INSERT_ID()
|
||||
""")
|
||||
created_face = cursor.fetchone()
|
||||
if not created_face:
|
||||
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
|
||||
|
||||
|
||||
# TODO 重新加载人脸模型
|
||||
update_face_data()
|
||||
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"人脸记录创建成功(ID: {created_face['id']})",
|
||||
data=FaceResponse(**created_face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取单个人脸记录
|
||||
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
||||
@encrypt_response()
|
||||
async def get_face(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = """
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = %s
|
||||
"""
|
||||
cursor.execute(query, (face_id,))
|
||||
face = cursor.fetchone()
|
||||
|
||||
if not face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=FaceResponse(**face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取人脸列表
|
||||
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
|
||||
@encrypt_response()
|
||||
async def get_face_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
has_eigenvalue: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if has_eigenvalue is not None:
|
||||
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
|
||||
|
||||
# 总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM face"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 列表数据
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
"""
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
face_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取成功(共{total}条)",
|
||||
data=FaceListResponse(
|
||||
total=total,
|
||||
faces=[FaceResponse(**face) for face in face_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 删除人脸记录
|
||||
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
||||
@encrypt_response()
|
||||
async def delete_face(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
|
||||
exist_face = cursor.fetchone()
|
||||
if not exist_face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
old_db_path = exist_face["address"]
|
||||
|
||||
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
|
||||
conn.commit()
|
||||
|
||||
# 删除图片
|
||||
if old_db_path:
|
||||
old_abs_path = Path(old_db_path).resolve()
|
||||
if old_abs_path.exists():
|
||||
try:
|
||||
old_abs_path.unlink()
|
||||
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
|
||||
extra_msg = "(已同步删除图片)"
|
||||
except Exception as e:
|
||||
print(f"[FaceRouter] 删除图片失败:{str(e)}")
|
||||
extra_msg = "(图片删除失败)"
|
||||
else:
|
||||
extra_msg = "(图片不存在)"
|
||||
else:
|
||||
extra_msg = "(无关联图片)"
|
||||
|
||||
|
||||
# TODO 重新加载人脸模型
|
||||
update_face_data()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
@router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片")
|
||||
# @encrypt_response()
|
||||
async def batch_import_faces(
|
||||
folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
success_count = 0 # 成功导入数量
|
||||
fail_list = [] # 失败记录(包含文件名、错误原因)
|
||||
try:
|
||||
# 1. 验证文件夹有效性
|
||||
folder = Path(folder_path)
|
||||
if not folder.exists() or not folder.is_dir():
|
||||
raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录")
|
||||
|
||||
# 2. 定义支持的图片格式
|
||||
supported_extensions = {".png", ".jpg", ".jpeg", ".webp"}
|
||||
|
||||
# 3. 数据库连接初始化
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 4. 遍历文件夹内所有文件
|
||||
for file_path in folder.iterdir():
|
||||
if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
|
||||
file_name = file_path.stem # 提取文件名(不含后缀)作为 `name`
|
||||
try:
|
||||
# 4.1 读取文件二进制内容
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
|
||||
# 4.2 构造模拟的 UploadFile 对象(用于兼容 `save_source_file`)
|
||||
mock_file = UploadFile(
|
||||
filename=file_path.name,
|
||||
file=BytesIO(file_content)
|
||||
)
|
||||
|
||||
# 4.3 保存文件到指定目录
|
||||
saved_path = save_source_file(mock_file, "face")
|
||||
|
||||
# 4.4 提取人脸特征
|
||||
detect_success, detect_result = add_binary_data(file_content)
|
||||
if not detect_success:
|
||||
fail_list.append({
|
||||
"name": file_name,
|
||||
"file_path": str(file_path),
|
||||
"error": f"人脸检测失败:{detect_result}"
|
||||
})
|
||||
continue # 跳过当前文件,处理下一个
|
||||
eigenvalue = detect_result
|
||||
|
||||
# 4.5 插入数据库
|
||||
insert_sql = """
|
||||
INSERT INTO face (name, eigenvalue, address)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path))
|
||||
conn.commit() # 提交当前文件的插入操作
|
||||
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
# 捕获单文件处理的异常,记录后继续处理其他文件
|
||||
fail_list.append({
|
||||
"name": file_name,
|
||||
"file_path": str(file_path),
|
||||
"error": str(e)
|
||||
})
|
||||
if conn:
|
||||
conn.rollback() # 回滚当前失败文件的插入
|
||||
|
||||
# 5. 重新加载人脸模型(确保新增数据生效)
|
||||
update_face_data()
|
||||
|
||||
# 6. 构造返回结果
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)} 条",
|
||||
data={
|
||||
"success_count": success_count,
|
||||
"fail_details": fail_list
|
||||
}
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e
|
||||
except HTTPException:
|
||||
raise # 直接抛出400等由业务逻辑触发的HTTP异常
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
31
router/file_router.py
Normal file
31
router/file_router.py
Normal file
@ -0,0 +1,31 @@
|
||||
import os
|
||||
from fastapi import FastAPI, HTTPException, Path, APIRouter
|
||||
from fastapi.responses import FileResponse
|
||||
from service.file_service import UPLOAD_ROOT
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/file",
|
||||
tags=["文件管理"]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/download/{relative_path:path}", summary="下载文件")
|
||||
async def download_file(
|
||||
relative_path: str = Path(..., description="文件的相对路径")
|
||||
):
|
||||
file_path = os.path.abspath(os.path.join(UPLOAD_ROOT, relative_path))
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=400, detail="路径指向的不是文件")
|
||||
|
||||
if not file_path.startswith(os.path.abspath(UPLOAD_ROOT)):
|
||||
raise HTTPException(status_code=403, detail="无权访问该文件")
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=os.path.basename(file_path),
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
269
router/model_router.py
Normal file
269
router/model_router.py
Normal file
@ -0,0 +1,269 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from service.file_service import save_source_file
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.model_schema import (
|
||||
ModelResponse,
|
||||
ModelListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from service.model_service import ALLOWED_MODEL_EXT, MAX_MODEL_SIZE, load_yolo_model
|
||||
|
||||
router = APIRouter(prefix="/api/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 上传模型
|
||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||
@encrypt_response()
|
||||
async def upload_model(
|
||||
name: str = Form(..., description="模型名称"),
|
||||
description: str = Form(None, description="模型描述"),
|
||||
file: UploadFile = File(..., description=f"YOLO模型文件(.pt、最大{MAX_MODEL_SIZE // 1024 // 1024}MB)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 校验文件格式
|
||||
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||
if file_ext not in ALLOWED_MODEL_EXT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"仅支持{ALLOWED_MODEL_EXT}格式、当前:{file_ext}"
|
||||
)
|
||||
|
||||
# 校验文件大小
|
||||
if file.size > MAX_MODEL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB、当前{file.size // 1024 // 1024}MB"
|
||||
)
|
||||
# 保存文件
|
||||
file_path = save_source_file(file, "model")
|
||||
|
||||
# 数据库操作
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
insert_sql = """
|
||||
INSERT INTO model (name, path, is_default, description, file_size)
|
||||
VALUES (%s, %s, 0, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_sql, (name, file_path, description, file.size))
|
||||
conn.commit()
|
||||
|
||||
# 获取新增记录
|
||||
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
|
||||
new_model = cursor.fetchone()
|
||||
if not new_model:
|
||||
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型上传成功",
|
||||
data=ModelResponse(**new_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取模型列表
|
||||
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||||
@encrypt_response()
|
||||
async def get_model_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
is_default: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if is_default is not None:
|
||||
where_clause.append("is_default = %s")
|
||||
params.append(1 if is_default else 0)
|
||||
|
||||
# 总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM model"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页数据
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM model"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
model_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取成功!",
|
||||
data=ModelListResponse(
|
||||
total=total,
|
||||
models=[ModelResponse(**model) for model in model_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 更换默认模型
|
||||
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型")
|
||||
@encrypt_response()
|
||||
async def set_default_model(
|
||||
model_id: int
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
conn.autocommit = False
|
||||
|
||||
# 校验目标模型是否存在
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
target_model = cursor.fetchone()
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail=f"目标模型不存在!")
|
||||
|
||||
# 检查是否已为默认模型
|
||||
if target_model["is_default"]:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"已是默认模型、无需更换",
|
||||
data=ModelResponse(**target_model)
|
||||
)
|
||||
|
||||
# 数据库事务:更新默认模型状态
|
||||
try:
|
||||
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
|
||||
cursor.execute(
|
||||
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
|
||||
(model_id,)
|
||||
)
|
||||
conn.commit()
|
||||
except MySQLError as e:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
|
||||
) from e
|
||||
|
||||
# 更新模型
|
||||
load_yolo_model()
|
||||
# 返回成功响应
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"更换成功",
|
||||
data=None
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
if conn:
|
||||
conn.autocommit = True
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 路由文件(如 model_router.py)中的删除接口
|
||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||
@encrypt_response()
|
||||
async def delete_model(model_id: int):
|
||||
# 1. 正确导入 model_service 中的全局变量(关键修复:变量名匹配)
|
||||
from service.model_service import (
|
||||
current_yolo_model,
|
||||
current_model_absolute_path,
|
||||
load_yolo_model # 用于删除后重新加载模型(可选)
|
||||
)
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 2. 查询待删除模型信息
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
exist_model = cursor.fetchone()
|
||||
if not exist_model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!")
|
||||
|
||||
# 3. 关键判断:①默认模型不可删 ②正在使用的模型不可删
|
||||
if exist_model["is_default"]:
|
||||
raise HTTPException(status_code=400, detail="默认模型不可删除!")
|
||||
|
||||
# 计算待删除模型的绝对路径(与 model_service 逻辑一致)
|
||||
from service.file_service import get_absolute_path
|
||||
del_model_abs_path = get_absolute_path(exist_model["path"])
|
||||
|
||||
# 判断是否正在使用(对比 current_model_absolute_path)
|
||||
if current_model_absolute_path and del_model_abs_path == current_model_absolute_path:
|
||||
raise HTTPException(status_code=400, detail="该模型正在使用中,禁止删除!")
|
||||
|
||||
# 4. 先删除数据库记录(避免文件删除失败导致数据不一致)
|
||||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||
conn.commit()
|
||||
|
||||
# 5. 再删除本地文件(捕获文件删除异常,不影响数据库删除结果)
|
||||
extra_msg = ""
|
||||
try:
|
||||
if os.path.exists(del_model_abs_path):
|
||||
os.remove(del_model_abs_path) # 或用 Path(del_model_abs_path).unlink()
|
||||
extra_msg = "(本地文件已同步删除)"
|
||||
else:
|
||||
extra_msg = "(本地文件不存在,无需删除)"
|
||||
except Exception as e:
|
||||
extra_msg = f"(本地文件删除失败:{str(e)})"
|
||||
|
||||
# 6. 若删除后当前模型为空(极端情况),重新加载默认模型(可选优化)
|
||||
if current_yolo_model is None:
|
||||
try:
|
||||
load_yolo_model()
|
||||
print(f"[模型删除后] 重新加载默认模型成功")
|
||||
except Exception as e:
|
||||
print(f"[模型删除后] 重新加载默认模型失败:{str(e)}")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型删除成功!",
|
||||
data=None
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
306
router/sensitive_router.py
Normal file
306
router/sensitive_router.py
Normal file
@ -0,0 +1,306 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, File, UploadFile
|
||||
from mysql.connector import Error as MySQLError
|
||||
from typing import Optional
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.sensitive_schema import (
|
||||
SensitiveCreateRequest,
|
||||
SensitiveResponse,
|
||||
SensitiveListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from middle.auth_middleware import get_current_user
|
||||
from schema.user_schema import UserResponse
|
||||
from service.ocr_service import set_forbidden_words
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/sensitives",
|
||||
tags=["敏感信息管理"]
|
||||
)
|
||||
|
||||
|
||||
# 创建敏感信息记录
|
||||
@router.post("", response_model=APIResponse, summary="创建敏感信息记录")
|
||||
@encrypt_response()
|
||||
async def create_sensitive(
|
||||
sensitive: SensitiveCreateRequest,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成)
|
||||
insert_query = """
|
||||
INSERT INTO sensitives (name, created_at, updated_at)
|
||||
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (sensitive.name,))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚插入记录的ID(使用LAST_INSERT_ID()函数)
|
||||
new_id = cursor.lastrowid
|
||||
|
||||
# 查询刚创建的记录并返回
|
||||
select_query = "SELECT * FROM sensitives WHERE id = %s"
|
||||
cursor.execute(select_query, (new_id,))
|
||||
created_sensitive = cursor.fetchone()
|
||||
|
||||
# 重新加载最新的敏感词
|
||||
set_forbidden_words(get_all_sensitive_words())
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="敏感信息记录创建成功",
|
||||
data=SensitiveResponse(**created_sensitive)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"创建敏感信息记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取敏感信息分页列表
|
||||
@router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
|
||||
@encrypt_response()
|
||||
async def get_sensitive_list(
|
||||
page: int = Query(1, ge=1, description="页码(默认1、最小1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10、1-100)"),
|
||||
name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 构建查询条件(支持关键词搜索)
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%") # 模糊匹配关键词
|
||||
|
||||
# 2. 查询总记录数(用于分页计算)
|
||||
count_sql = "SELECT COUNT(*) AS total FROM sensitives"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params.copy()) # 复制参数列表、避免后续污染
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 计算分页偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 4. 分页查询敏感词数据(按更新时间倒序、最新的在前)
|
||||
list_sql = "SELECT * FROM sensitives"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
# 排序+分页(LIMIT 条数 OFFSET 偏移量)
|
||||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||
# 补充分页参数(page_size和offset)
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
sensitive_list = cursor.fetchall()
|
||||
|
||||
# 5. 构造分页响应数据
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"敏感信息列表查询成功(共{total}条记录、当前第{page}页)",
|
||||
data=SensitiveListResponse(
|
||||
total=total,
|
||||
sensitives=[SensitiveResponse(**item) for item in sensitive_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询敏感信息列表失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 删除敏感信息记录
|
||||
@router.delete("/{sensitive_id}", response_model=APIResponse, summary="删除敏感信息记录")
|
||||
@encrypt_response()
|
||||
async def delete_sensitive(
|
||||
sensitive_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||
):
|
||||
"""
|
||||
删除敏感信息记录:
|
||||
- 需登录认证
|
||||
- 根据ID删除敏感信息记录
|
||||
- 返回删除成功信息
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 检查记录是否存在
|
||||
check_query = "SELECT id FROM sensitives WHERE id = %s"
|
||||
cursor.execute(check_query, (sensitive_id,))
|
||||
existing_sensitive = cursor.fetchone()
|
||||
if not existing_sensitive:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ID为 {sensitive_id} 的敏感信息记录不存在"
|
||||
)
|
||||
|
||||
# 2. 执行删除操作
|
||||
delete_query = "DELETE FROM sensitives WHERE id = %s"
|
||||
cursor.execute(delete_query, (sensitive_id,))
|
||||
conn.commit()
|
||||
|
||||
# 重新加载最新的敏感词
|
||||
set_forbidden_words(get_all_sensitive_words())
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"ID为 {sensitive_id} 的敏感信息记录删除成功",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"删除敏感信息记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 批量导入敏感信息(从txt文件)
|
||||
@router.post("/batch-import", response_model=APIResponse, summary="批量导入敏感信息(从txt文件)")
|
||||
@encrypt_response()
|
||||
async def batch_import_sensitives(
|
||||
file: UploadFile = File(..., description="包含敏感词的txt文件,每行一个敏感词"),
|
||||
# current_user: UserResponse = Depends(get_current_user) # 添加认证依赖
|
||||
):
|
||||
"""
|
||||
批量导入敏感信息:
|
||||
- 需登录认证
|
||||
- 接收txt文件,文件中每行一个敏感词
|
||||
- 批量插入到数据库中(仅插入不存在的敏感词)
|
||||
- 返回导入结果统计
|
||||
"""
|
||||
# 检查文件类型
|
||||
filename = file.filename or ""
|
||||
if not filename.lower().endswith(".txt"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"请上传txt格式的文件,当前文件格式: {filename.split('.')[-1] if '.' in filename else '未知'}"
|
||||
)
|
||||
|
||||
# 检查文件大小
|
||||
file_size = await file.read(1) # 读取1字节获取文件信息
|
||||
await file.seek(0) # 重置文件指针
|
||||
if not file_size: # 文件为空
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="上传的文件为空,请提供有效的敏感词文件"
|
||||
)
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 读取文件内容
|
||||
contents = await file.read()
|
||||
# 按行分割内容,处理不同操作系统的换行符
|
||||
lines = contents.decode("utf-8", errors="replace").splitlines()
|
||||
|
||||
# 过滤空行和仅含空白字符的行
|
||||
sensitive_words = [line.strip() for line in lines if line.strip()]
|
||||
|
||||
if not sensitive_words:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="文件中没有有效的敏感词",
|
||||
data={"imported": 0, "total": 0}
|
||||
)
|
||||
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 先查询数据库中已存在的敏感词
|
||||
query = "SELECT name FROM sensitives WHERE name IN (%s)"
|
||||
# 处理参数,根据敏感词数量生成占位符
|
||||
placeholders = ', '.join(['%s'] * len(sensitive_words))
|
||||
cursor.execute(query % placeholders, sensitive_words)
|
||||
existing_words = {row['name'] for row in cursor.fetchall()}
|
||||
|
||||
# 过滤掉已存在的敏感词
|
||||
new_words = [word for word in sensitive_words if word not in existing_words]
|
||||
|
||||
if not new_words:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="所有敏感词均已存在于数据库中",
|
||||
data={
|
||||
"total": len(sensitive_words),
|
||||
"imported": 0,
|
||||
"duplicates": len(sensitive_words),
|
||||
"message": f"共处理{len(sensitive_words)}个敏感词,全部已存在,未导入任何新敏感词"
|
||||
}
|
||||
)
|
||||
|
||||
# 批量插入新的敏感词
|
||||
insert_query = """
|
||||
INSERT INTO sensitives (name, created_at, updated_at)
|
||||
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
|
||||
# 准备参数列表
|
||||
params = [(word,) for word in new_words]
|
||||
|
||||
# 执行批量插入
|
||||
cursor.executemany(insert_query, params)
|
||||
conn.commit()
|
||||
|
||||
# 重新加载最新的敏感词
|
||||
set_forbidden_words(get_all_sensitive_words())
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"敏感词批量导入成功",
|
||||
data={
|
||||
"total": len(sensitive_words),
|
||||
"imported": len(new_words),
|
||||
"duplicates": len(sensitive_words) - len(new_words),
|
||||
"message": f"共处理{len(sensitive_words)}个敏感词,成功导入{len(new_words)}个,{len(sensitive_words) - len(new_words)}个已存在"
|
||||
}
|
||||
)
|
||||
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="文件编码格式错误,请使用UTF-8编码的txt文件"
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"批量导入敏感词失败: {str(e)}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"处理文件时发生错误: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
246
router/user_router.py
Normal file
246
router/user_router.py
Normal file
@ -0,0 +1,246 @@
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from middle.auth_middleware import (
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
get_current_user
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/users",
|
||||
tags=["用户管理"]
|
||||
)
|
||||
|
||||
|
||||
# 用户注册接口
|
||||
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
||||
@encrypt_response()
|
||||
async def user_register(request: UserRegisterRequest):
|
||||
"""
|
||||
用户注册:
|
||||
- 校验用户名是否已存在
|
||||
- 加密密码后插入数据库
|
||||
- 返回注册成功信息
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 检查用户名是否已存在(唯一索引)
|
||||
check_query = "SELECT username FROM users WHERE username = %s"
|
||||
cursor.execute(check_query, (request.username,))
|
||||
existing_user = cursor.fetchone()
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"用户名 '{request.username}' 已存在、请更换其他用户名"
|
||||
)
|
||||
|
||||
# 2. 加密密码
|
||||
hashed_password = get_password_hash(request.password)
|
||||
|
||||
# 3. 插入新用户到数据库
|
||||
insert_query = """
|
||||
INSERT INTO users (username, password)
|
||||
VALUES (%s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (request.username, hashed_password))
|
||||
conn.commit() # 提交事务
|
||||
|
||||
# 4. 返回注册成功响应
|
||||
return APIResponse(
|
||||
code=200, # 200 表示资源创建成功
|
||||
message=f"用户 '{request.username}' 注册成功",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
conn.rollback() # 数据库错误时回滚事务
|
||||
raise Exception(f"注册失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 用户登录接口
|
||||
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
||||
@encrypt_response()
|
||||
async def user_login(request: UserLoginRequest):
|
||||
"""
|
||||
用户登录:
|
||||
- 校验用户名是否存在
|
||||
- 校验密码是否正确
|
||||
- 生成 JWT Token 并返回
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 修复: SQL查询添加 created_at 和 updated_at 字段
|
||||
query = """
|
||||
SELECT id, username, password, created_at, updated_at
|
||||
FROM users
|
||||
WHERE username = %s
|
||||
"""
|
||||
cursor.execute(query, (request.username,))
|
||||
user = cursor.fetchone()
|
||||
|
||||
# 2. 校验用户名和密码
|
||||
if not user or not verify_password(request.password, user["password"]):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 3. 生成 Token(过期时间从配置读取)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user["username"]},
|
||||
expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
# 4. 返回 Token 和用户基本信息
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="登录成功",
|
||||
data={
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"user": UserResponse(
|
||||
id=user["id"],
|
||||
username=user["username"],
|
||||
created_at=user.get("created_at"),
|
||||
updated_at=user.get("updated_at")
|
||||
)
|
||||
}
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"登录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取当前登录用户信息(需认证)
|
||||
@router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
|
||||
@encrypt_response()
|
||||
async def get_current_user_info(
|
||||
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
||||
):
|
||||
"""
|
||||
获取当前登录用户信息:
|
||||
- 需在请求头携带 Token(格式: Bearer <token>)
|
||||
- 认证通过后返回用户信息
|
||||
"""
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取用户信息成功",
|
||||
data=current_user
|
||||
)
|
||||
|
||||
|
||||
# 获取用户列表(仅需登录权限)
|
||||
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
|
||||
@encrypt_response()
|
||||
async def get_user_list(
|
||||
page: int = Query(1, ge=1, description="页码、从1开始"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
username: Optional[str] = Query(None, description="用户名模糊搜索"),
|
||||
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
|
||||
):
|
||||
"""
|
||||
获取用户列表:
|
||||
- 需登录权限(请求头携带 Token: Bearer <token>)
|
||||
- 支持分页查询(page=页码、page_size=每页条数)
|
||||
- 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等)
|
||||
- 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息)
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 计算分页偏移量(page从1开始、偏移量=(页码-1)*每页条数)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 基础查询(仅查非敏感字段)
|
||||
base_query = """
|
||||
SELECT id, username, created_at, updated_at
|
||||
FROM users
|
||||
"""
|
||||
# 总条数查询(用于分页计算)
|
||||
count_query = "SELECT COUNT(*) as total FROM users"
|
||||
|
||||
# 条件拼接(支持用户名模糊搜索)
|
||||
conditions = []
|
||||
params = []
|
||||
if username:
|
||||
conditions.append("username LIKE %s")
|
||||
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
|
||||
|
||||
# 构建最终查询语句
|
||||
if conditions:
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
|
||||
final_count_query = f"{count_query}{where_clause}"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
else:
|
||||
final_query = f"{base_query} LIMIT %s OFFSET %s"
|
||||
final_count_query = count_query
|
||||
params = [page_size, offset]
|
||||
|
||||
# 1. 查询用户列表数据
|
||||
cursor.execute(final_query, params)
|
||||
users = cursor.fetchall()
|
||||
|
||||
# 2. 查询总条数(用于计算总页数)
|
||||
count_params = [f"%{username}%"] if username else []
|
||||
cursor.execute(final_count_query, count_params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 转换为UserResponse模型(确保字段匹配)
|
||||
user_list = [
|
||||
UserResponse(
|
||||
id=user["id"],
|
||||
username=user["username"],
|
||||
created_at=user["created_at"],
|
||||
updated_at=user["updated_at"]
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
# 4. 计算总页数(向上取整、如11条数据每页10条=2页)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# 返回结果(包含列表和分页信息)
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取用户列表成功",
|
||||
data={
|
||||
"users": user_list,
|
||||
"pagination": {
|
||||
"page": page, # 当前页码
|
||||
"page_size": page_size, # 每页条数
|
||||
"total": total, # 总数据量
|
||||
"total_pages": total_pages # 总页数
|
||||
}
|
||||
}
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取用户列表失败: {str(e)}") from e
|
||||
finally:
|
||||
# 无论成功失败、都关闭数据库连接
|
||||
db.close_connection(conn, cursor)
|
||||
Reference in New Issue
Block a user