模型添加置信度设置,敏感词分页
This commit is contained in:
13
core/yolo.py
13
core/yolo.py
@ -1,5 +1,5 @@
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
from service.model_service import get_current_yolo_model # 带版本校验的模型获取
|
from service.model_service import get_current_yolo_model, get_current_conf_threshold # 新增置信度获取函数
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path=None):
|
def load_model(model_path=None):
|
||||||
@ -15,8 +15,8 @@ def load_model(model_path=None):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def detect(frame, conf_threshold=0.7):
|
def detect(frame):
|
||||||
"""执行目标检测(仅模型版本变化时重新加载,平时复用缓存)"""
|
"""执行目标检测(使用动态置信度,仅模型版本变化时重新加载)"""
|
||||||
# 获取模型(内部已做版本校验,未变化则直接返回缓存)
|
# 获取模型(内部已做版本校验,未变化则直接返回缓存)
|
||||||
current_model = load_model()
|
current_model = load_model()
|
||||||
if not current_model:
|
if not current_model:
|
||||||
@ -26,6 +26,8 @@ def detect(frame, conf_threshold=0.7):
|
|||||||
return (False, "无效输入帧")
|
return (False, "无效输入帧")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 获取动态置信度(从全局配置中读取)
|
||||||
|
conf_threshold = get_current_conf_threshold()
|
||||||
# 用当前模型执行检测(复用缓存,无额外加载耗时)
|
# 用当前模型执行检测(复用缓存,无额外加载耗时)
|
||||||
results = current_model(frame, conf=conf_threshold, verbose=False)
|
results = current_model(frame, conf=conf_threshold, verbose=False)
|
||||||
has_results = len(results[0].boxes) > 0 if results else False
|
has_results = len(results[0].boxes) > 0 if results else False
|
||||||
@ -43,11 +45,6 @@ def detect(frame, conf_threshold=0.7):
|
|||||||
class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}"
|
class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}"
|
||||||
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})")
|
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})")
|
||||||
|
|
||||||
# 打印当前使用的模型路径和版本(用于验证)
|
|
||||||
# model_path = getattr(current_model, "model_path", "未知路径")
|
|
||||||
# from service.model_service import _current_model_version
|
|
||||||
# print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...)")
|
|
||||||
|
|
||||||
return (True, "; ".join(result_parts))
|
return (True, "; ".join(result_parts))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -7,24 +8,31 @@ from pydantic import BaseModel, Field
|
|||||||
# ------------------------------
|
# ------------------------------
|
||||||
class SensitiveCreateRequest(BaseModel):
|
class SensitiveCreateRequest(BaseModel):
|
||||||
"""创建敏感信息记录请求模型"""
|
"""创建敏感信息记录请求模型"""
|
||||||
# 移除了id字段、由数据库自动生成
|
name: str = Field(..., max_length=255, description="敏感词内容(必填)")
|
||||||
name: str = Field(None, max_length=255, description="名称")
|
|
||||||
|
|
||||||
|
|
||||||
class SensitiveUpdateRequest(BaseModel):
|
class SensitiveUpdateRequest(BaseModel):
|
||||||
"""更新敏感信息记录请求模型"""
|
"""更新敏感信息记录请求模型"""
|
||||||
name: str = Field(None, max_length=255, description="名称")
|
name: Optional[str] = Field(None, max_length=255, description="敏感词内容(可选修改)")
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 响应模型(后端返回数据)
|
# 响应模型(后端返回数据)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
class SensitiveResponse(BaseModel):
|
class SensitiveResponse(BaseModel):
|
||||||
"""敏感信息记录响应模型"""
|
"""敏感信息单条记录响应模型"""
|
||||||
id: int = Field(..., description="主键ID") # 响应中仍然包含ID
|
id: int = Field(..., description="主键ID")
|
||||||
name: str = Field(None, description="名称")
|
name: str = Field(..., description="敏感词内容")
|
||||||
created_at: datetime = Field(..., description="记录创建时间")
|
created_at: datetime = Field(..., description="记录创建时间")
|
||||||
updated_at: datetime = Field(..., description="记录更新时间")
|
updated_at: datetime = Field(..., description="记录更新时间")
|
||||||
|
|
||||||
# 支持从数据库查询结果转换
|
# 支持从数据库查询结果(字典/对象)自动转换
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class SensitiveListResponse(BaseModel):
|
||||||
|
"""敏感信息分页列表响应模型(新增)"""
|
||||||
|
total: int = Field(..., description="敏感词总记录数")
|
||||||
|
sensitives: List[SensitiveResponse] = Field(..., description="当前页敏感词列表")
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
@ -31,15 +31,16 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
|
|||||||
ALLOWED_MODEL_EXT = {"pt"}
|
ALLOWED_MODEL_EXT = {"pt"}
|
||||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||||
|
|
||||||
# 全局模型变量(带版本标识)
|
# 全局模型变量(带版本标识和置信度)
|
||||||
global _yolo_model, _current_model_version
|
global _yolo_model, _current_model_version, _current_conf_threshold
|
||||||
_yolo_model = None
|
_yolo_model = None
|
||||||
_current_model_version = None # 模型版本标识(用于检测模型是否变化)
|
_current_model_version = None # 模型版本标识
|
||||||
|
_current_conf_threshold = 0.8 # 默认置信度初始值
|
||||||
|
|
||||||
router = APIRouter(prefix="/models", tags=["模型管理"])
|
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||||
|
|
||||||
|
|
||||||
# 服务重启核心工具函数
|
# 服务重启核心工具函数(保持不变)
|
||||||
def restart_service():
|
def restart_service():
|
||||||
"""重启当前FastAPI服务进程"""
|
"""重启当前FastAPI服务进程"""
|
||||||
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
|
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
|
||||||
@ -87,7 +88,7 @@ def restart_service():
|
|||||||
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# 模型路径验证工具函数
|
# 模型路径验证工具函数(保持不变)
|
||||||
def get_valid_model_abs_path(relative_path: str) -> str:
|
def get_valid_model_abs_path(relative_path: str) -> str:
|
||||||
try:
|
try:
|
||||||
relative_path = relative_path.replace("/", os.sep)
|
relative_path = relative_path.replace("/", os.sep)
|
||||||
@ -139,7 +140,7 @@ def get_valid_model_abs_path(relative_path: str) -> str:
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
# 对外提供当前模型(带版本校验)
|
# 对外提供当前模型(带版本校验)(保持不变)
|
||||||
def get_current_yolo_model():
|
def get_current_yolo_model():
|
||||||
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
|
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
|
||||||
global _yolo_model, _current_model_version
|
global _yolo_model, _current_model_version
|
||||||
@ -155,21 +156,19 @@ def get_current_yolo_model():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 1. 计算当前默认模型的唯一版本标识
|
# 1. 计算当前默认模型的唯一版本标识
|
||||||
# (路径哈希 + 文件修改时间戳,确保模型变化时版本变化)
|
|
||||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||||
model_stat = os.stat(valid_abs_path)
|
model_stat = os.stat(valid_abs_path)
|
||||||
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||||
|
|
||||||
# 2. 版本未变化则复用已有模型(核心优化点)
|
# 2. 版本未变化则复用已有模型
|
||||||
if _yolo_model and _current_model_version == model_version:
|
if _yolo_model and _current_model_version == model_version:
|
||||||
# print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...)")
|
|
||||||
return _yolo_model
|
return _yolo_model
|
||||||
|
|
||||||
# 3. 版本变化时重新加载模型
|
# 3. 版本变化时重新加载模型
|
||||||
_yolo_model = load_yolo_model(valid_abs_path)
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
if _yolo_model:
|
if _yolo_model:
|
||||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||||
_current_model_version = model_version # 更新版本标识
|
_current_model_version = model_version
|
||||||
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)")
|
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)")
|
||||||
else:
|
else:
|
||||||
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
|
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
|
||||||
@ -182,7 +181,14 @@ def get_current_yolo_model():
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 1. 上传模型
|
# 新增:获取当前置信度阈值
|
||||||
|
def get_current_conf_threshold():
|
||||||
|
"""供检测模块获取当前设置的置信度阈值"""
|
||||||
|
global _current_conf_threshold
|
||||||
|
return _current_conf_threshold
|
||||||
|
|
||||||
|
|
||||||
|
# 1. 上传模型(保持不变)
|
||||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||||
async def upload_model(
|
async def upload_model(
|
||||||
name: str = Form(..., description="模型名称"),
|
name: str = Form(..., description="模型名称"),
|
||||||
@ -255,7 +261,7 @@ async def upload_model(
|
|||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=201,
|
code=201,
|
||||||
message=f"模型上传成功!ID:{new_model['id']}",
|
message=f"模型上传成功!ID:{new_model['id']}",
|
||||||
data=ModelResponse(**new_model)
|
data=ModelResponse(** new_model)
|
||||||
)
|
)
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
@ -273,7 +279,7 @@ async def upload_model(
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 2. 获取模型列表
|
# 2. 获取模型列表(保持不变)
|
||||||
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||||||
async def get_model_list(
|
async def get_model_list(
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
@ -319,7 +325,7 @@ async def get_model_list(
|
|||||||
message=f"获取成功!共{total}条记录",
|
message=f"获取成功!共{total}条记录",
|
||||||
data=ModelListResponse(
|
data=ModelListResponse(
|
||||||
total=total,
|
total=total,
|
||||||
models=[ModelResponse(**model) for model in model_list]
|
models=[ModelResponse(** model) for model in model_list]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -329,7 +335,7 @@ async def get_model_list(
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 3. 获取默认模型
|
# 3. 获取默认模型(保持不变)
|
||||||
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
|
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
|
||||||
async def get_default_model():
|
async def get_default_model():
|
||||||
conn = None
|
conn = None
|
||||||
@ -371,7 +377,7 @@ async def get_default_model():
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 4. 获取单个模型详情
|
# 4. 获取单个模型详情(保持不变)
|
||||||
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
|
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
|
||||||
async def get_model(model_id: int):
|
async def get_model(model_id: int):
|
||||||
conn = None
|
conn = None
|
||||||
@ -392,7 +398,7 @@ async def get_model(model_id: int):
|
|||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message=f"查询成功,但路径异常:{e.detail}",
|
message=f"查询成功,但路径异常:{e.detail}",
|
||||||
data=ModelResponse(**model)
|
data=ModelResponse(** model)
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
@ -400,14 +406,13 @@ async def get_model(model_id: int):
|
|||||||
message="查询成功",
|
message="查询成功",
|
||||||
data=ModelResponse(**model)
|
data=ModelResponse(**model)
|
||||||
)
|
)
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 5. 更新模型信息
|
# 5. 更新模型信息(保持不变)
|
||||||
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
|
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
|
||||||
async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||||
conn = None
|
conn = None
|
||||||
@ -479,7 +484,7 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
|||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="模型更新成功",
|
message="模型更新成功",
|
||||||
data=ModelResponse(**updated_model)
|
data=ModelResponse(** updated_model)
|
||||||
)
|
)
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
@ -490,9 +495,12 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 5.1 更换默认模型(自动重启服务)
|
# 5.1 更换默认模型(添加置信度参数)
|
||||||
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
|
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
|
||||||
async def set_default_model(model_id: int):
|
async def set_default_model(
|
||||||
|
model_id: int,
|
||||||
|
conf_threshold: float = Query(0.8, ge=0.01, le=0.99, description="模型检测置信度阈值(0.01-0.99)")
|
||||||
|
):
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
@ -551,10 +559,11 @@ async def set_default_model(model_id: int):
|
|||||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
updated_model = cursor.fetchone()
|
updated_model = cursor.fetchone()
|
||||||
|
|
||||||
# 7. 重置版本标识(关键:确保下次检测加载新模型)
|
# 7. 重置版本标识和更新置信度
|
||||||
global _current_model_version
|
global _current_model_version, _current_conf_threshold
|
||||||
_current_model_version = None
|
_current_model_version = None
|
||||||
print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型")
|
_current_conf_threshold = conf_threshold # 保存动态置信度
|
||||||
|
print(f"[更换默认模型] 已重置模型版本标识,设置新置信度:{conf_threshold}")
|
||||||
|
|
||||||
# 8. 延迟重启服务
|
# 8. 延迟重启服务
|
||||||
print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})")
|
print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})")
|
||||||
@ -566,8 +575,8 @@ async def set_default_model(model_id: int):
|
|||||||
# 9. 返回成功响应
|
# 9. 返回成功响应
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message=f"已成功更换默认模型(ID:{model_id})!服务将在1秒后自动重启以应用新模型",
|
message=f"已成功更换默认模型(ID:{model_id}),置信度:{conf_threshold}!服务将在1秒后自动重启以应用新模型",
|
||||||
data=ModelResponse(**updated_model)
|
data=ModelResponse(** updated_model)
|
||||||
)
|
)
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
@ -580,7 +589,7 @@ async def set_default_model(model_id: int):
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 6. 删除模型
|
# 6. 删除模型(保持不变)
|
||||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||||
async def delete_model(model_id: int):
|
async def delete_model(model_id: int):
|
||||||
conn = None
|
conn = None
|
||||||
@ -639,7 +648,7 @@ async def delete_model(model_id: int):
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 7. 下载模型文件
|
# 7. 下载模型文件(保持不变)
|
||||||
@router.get("/{model_id}/download", summary="下载模型文件")
|
@router.get("/{model_id}/download", summary="下载模型文件")
|
||||||
async def download_model(model_id: int):
|
async def download_model(model_id: int):
|
||||||
conn = None
|
conn = None
|
||||||
|
|||||||
@ -1,8 +1,14 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from mysql.connector import Error as MySQLError
|
from mysql.connector import Error as MySQLError
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
from schema.sensitive_schema import SensitiveCreateRequest, SensitiveUpdateRequest, SensitiveResponse
|
from schema.sensitive_schema import (
|
||||||
|
SensitiveCreateRequest,
|
||||||
|
SensitiveUpdateRequest,
|
||||||
|
SensitiveResponse,
|
||||||
|
SensitiveListResponse # 导入新增的分页响应模型
|
||||||
|
)
|
||||||
from schema.response_schema import APIResponse
|
from schema.response_schema import APIResponse
|
||||||
from middle.auth_middleware import get_current_user
|
from middle.auth_middleware import get_current_user
|
||||||
from schema.user_schema import UserResponse
|
from schema.user_schema import UserResponse
|
||||||
@ -19,7 +25,9 @@ router = APIRouter(
|
|||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.post("", response_model=APIResponse, summary="创建敏感信息记录")
|
@router.post("", response_model=APIResponse, summary="创建敏感信息记录")
|
||||||
async def create_sensitive(
|
async def create_sensitive(
|
||||||
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖
|
sensitive: SensitiveCreateRequest,
|
||||||
|
current_user: UserResponse = Depends(get_current_user) # 补充登录认证依赖(与其他接口保持一致)
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
创建敏感信息记录:
|
创建敏感信息记录:
|
||||||
- 需登录认证
|
- 需登录认证
|
||||||
@ -34,8 +42,8 @@ async def create_sensitive(
|
|||||||
|
|
||||||
# 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成)
|
# 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成)
|
||||||
insert_query = """
|
insert_query = """
|
||||||
INSERT INTO sensitives (name)
|
INSERT INTO sensitives (name, created_at, updated_at)
|
||||||
VALUES (%s)
|
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
"""
|
"""
|
||||||
cursor.execute(insert_query, (sensitive.name,))
|
cursor.execute(insert_query, (sensitive.name,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@ -56,12 +64,14 @@ async def create_sensitive(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"创建敏感信息记录失败: {str(e)}") from e
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"创建敏感信息记录失败: {str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 以下接口代码保持不变
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 2. 获取单个敏感信息记录
|
# 2. 获取单个敏感信息记录
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@ -98,21 +108,29 @@ async def get_sensitive(
|
|||||||
data=SensitiveResponse(**sensitive)
|
data=SensitiveResponse(**sensitive)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"查询敏感信息记录失败: {str(e)}") from e
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"查询敏感信息记录失败: {str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 3. 获取所有敏感信息记录
|
# 3. 获取敏感信息分页列表(重构:支持分页+关键词搜索)
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录")
|
@router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
|
||||||
async def get_all_sensitives():
|
async def get_sensitive_list(
|
||||||
|
page: int = Query(1, ge=1, description="页码(默认1,最小1)"),
|
||||||
|
page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10,1-100)"),
|
||||||
|
name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)"),
|
||||||
|
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取所有敏感信息记录:
|
获取敏感信息分页列表:
|
||||||
- 需登录认证
|
- 需登录认证
|
||||||
- 查询所有敏感信息记录(不需要分页)
|
- 支持分页(page/page_size)和敏感词关键词模糊搜索(name)
|
||||||
- 返回所有敏感信息列表
|
- 返回总记录数+当前页数据
|
||||||
"""
|
"""
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
@ -120,17 +138,49 @@ async def get_all_sensitives():
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
query = "SELECT * FROM sensitives ORDER BY id"
|
# 1. 构建查询条件(支持关键词搜索)
|
||||||
cursor.execute(query)
|
where_clause = []
|
||||||
sensitives = cursor.fetchall()
|
params = []
|
||||||
|
if name:
|
||||||
|
where_clause.append("name LIKE %s")
|
||||||
|
params.append(f"%{name}%") # 模糊匹配关键词
|
||||||
|
|
||||||
|
# 2. 查询总记录数(用于分页计算)
|
||||||
|
count_sql = "SELECT COUNT(*) AS total FROM sensitives"
|
||||||
|
if where_clause:
|
||||||
|
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||||
|
cursor.execute(count_sql, params.copy()) # 复制参数列表,避免后续污染
|
||||||
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
|
# 3. 计算分页偏移量
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# 4. 分页查询敏感词数据(按更新时间倒序,最新的在前)
|
||||||
|
list_sql = "SELECT * FROM sensitives"
|
||||||
|
if where_clause:
|
||||||
|
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||||
|
# 排序+分页(LIMIT 条数 OFFSET 偏移量)
|
||||||
|
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||||
|
# 补充分页参数(page_size和offset)
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
|
||||||
|
cursor.execute(list_sql, params)
|
||||||
|
sensitive_list = cursor.fetchall()
|
||||||
|
|
||||||
|
# 5. 构造分页响应数据
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="所有敏感信息记录查询成功",
|
message=f"敏感信息列表查询成功(共{total}条记录,当前第{page}页)",
|
||||||
data=[SensitiveResponse(**sensitive) for sensitive in sensitives]
|
data=SensitiveListResponse(
|
||||||
|
total=total,
|
||||||
|
sensitives=[SensitiveResponse(**item) for item in sensitive_list]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"查询敏感信息列表失败: {str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -177,14 +227,16 @@ async def update_sensitive(
|
|||||||
if not update_fields:
|
if not update_fields:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail="至少需要提供一个字段进行更新"
|
detail="至少需要提供一个字段进行更新(如:name)"
|
||||||
)
|
)
|
||||||
|
|
||||||
params.append(sensitive_id) # WHERE条件的参数
|
# 补充更新时间和WHERE条件参数
|
||||||
|
update_fields.append("updated_at = CURRENT_TIMESTAMP")
|
||||||
|
params.append(sensitive_id)
|
||||||
|
|
||||||
update_query = f"""
|
update_query = f"""
|
||||||
UPDATE sensitives
|
UPDATE sensitives
|
||||||
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
SET {', '.join(update_fields)}
|
||||||
WHERE id = %s
|
WHERE id = %s
|
||||||
"""
|
"""
|
||||||
cursor.execute(update_query, params)
|
cursor.execute(update_query, params)
|
||||||
@ -203,7 +255,10 @@ async def update_sensitive(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"更新敏感信息记录失败: {str(e)}") from e
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"更新敏感信息记录失败: {str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
@ -251,14 +306,20 @@ async def delete_sensitive(
|
|||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"删除敏感信息记录失败: {str(e)}") from e
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"删除敏感信息记录失败: {str(e)}"
|
||||||
|
) from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 6. 业务辅助函数:获取所有敏感词(供其他模块调用)
|
||||||
|
# ------------------------------
|
||||||
def get_all_sensitive_words() -> list[str]:
|
def get_all_sensitive_words() -> list[str]:
|
||||||
"""
|
"""
|
||||||
获取所有敏感词、返回字符串数组
|
获取所有敏感词(返回纯字符串列表,用于过滤业务)
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
list[str]: 包含所有敏感词的数组
|
list[str]: 包含所有敏感词的数组
|
||||||
@ -273,17 +334,17 @@ def get_all_sensitive_words() -> list[str]:
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 执行查询、只获取敏感词字段
|
# 执行查询(只获取敏感词字段,按ID排序)
|
||||||
query = "SELECT name FROM sensitives ORDER BY id"
|
query = "SELECT name FROM sensitives ORDER BY id"
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
sensitive_records = cursor.fetchall()
|
sensitive_records = cursor.fetchall()
|
||||||
|
|
||||||
# 提取敏感词到数组中
|
# 提取敏感词到纯字符串数组
|
||||||
return [record['name'] for record in sensitive_records]
|
return [record['name'] for record in sensitive_records]
|
||||||
|
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
# 数据库错误处理
|
# 数据库错误向上抛出,由调用方处理
|
||||||
raise MySQLError(f"查询敏感词失败: {str(e)}") from e
|
raise MySQLError(f"查询敏感词列表失败: {str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
# 确保资源正确释放
|
# 确保数据库连接正确释放
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
Reference in New Issue
Block a user