模型添加置信度设置,敏感词分页

This commit is contained in:
2025-09-15 17:43:36 +08:00
parent d9192bd964
commit 5959f9994c
4 changed files with 156 additions and 81 deletions

View File

@ -1,5 +1,5 @@
from ultralytics import YOLO from ultralytics import YOLO
from service.model_service import get_current_yolo_model # 带版本校验的模型获取 from service.model_service import get_current_yolo_model, get_current_conf_threshold # 新增置信度获取函数
def load_model(model_path=None): def load_model(model_path=None):
@ -15,8 +15,8 @@ def load_model(model_path=None):
return None return None
def detect(frame, conf_threshold=0.7): def detect(frame):
"""执行目标检测(仅模型版本变化时重新加载,平时复用缓存""" """执行目标检测(使用动态置信度,仅模型版本变化时重新加载)"""
# 获取模型(内部已做版本校验,未变化则直接返回缓存) # 获取模型(内部已做版本校验,未变化则直接返回缓存)
current_model = load_model() current_model = load_model()
if not current_model: if not current_model:
@ -26,6 +26,8 @@ def detect(frame, conf_threshold=0.7):
return (False, "无效输入帧") return (False, "无效输入帧")
try: try:
# 获取动态置信度(从全局配置中读取)
conf_threshold = get_current_conf_threshold()
# 用当前模型执行检测(复用缓存,无额外加载耗时) # 用当前模型执行检测(复用缓存,无额外加载耗时)
results = current_model(frame, conf=conf_threshold, verbose=False) results = current_model(frame, conf=conf_threshold, verbose=False)
has_results = len(results[0].boxes) > 0 if results else False has_results = len(results[0].boxes) > 0 if results else False
@ -43,13 +45,8 @@ def detect(frame, conf_threshold=0.7):
class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}" class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}"
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox}") result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox}")
# 打印当前使用的模型路径和版本(用于验证)
# model_path = getattr(current_model, "model_path", "未知路径")
# from service.model_service import _current_model_version
# print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...")
return (True, "; ".join(result_parts)) return (True, "; ".join(result_parts))
except Exception as e: except Exception as e:
print(f"YOLO检测过程出错{str(e)}") print(f"YOLO检测过程出错{str(e)}")
return (False, f"检测错误:{str(e)}") return (False, f"检测错误:{str(e)}")

View File

@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional
# ------------------------------ # ------------------------------
@ -7,24 +8,31 @@ from pydantic import BaseModel, Field
# ------------------------------ # ------------------------------
class SensitiveCreateRequest(BaseModel): class SensitiveCreateRequest(BaseModel):
"""创建敏感信息记录请求模型""" """创建敏感信息记录请求模型"""
# 移除了id字段、由数据库自动生成 name: str = Field(..., max_length=255, description="敏感词内容(必填)")
name: str = Field(None, max_length=255, description="名称")
class SensitiveUpdateRequest(BaseModel): class SensitiveUpdateRequest(BaseModel):
"""更新敏感信息记录请求模型""" """更新敏感信息记录请求模型"""
name: str = Field(None, max_length=255, description="名称") name: Optional[str] = Field(None, max_length=255, description="敏感词内容(可选修改)")
# ------------------------------ # ------------------------------
# 响应模型(后端返回数据) # 响应模型(后端返回数据)
# ------------------------------ # ------------------------------
class SensitiveResponse(BaseModel): class SensitiveResponse(BaseModel):
"""敏感信息记录响应模型""" """敏感信息单条记录响应模型"""
id: int = Field(..., description="主键ID") # 响应中仍然包含ID id: int = Field(..., description="主键ID")
name: str = Field(None, description="名称") name: str = Field(..., description="敏感词内容")
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库查询结果转换 # 支持从数据库查询结果(字典/对象)自动转换
model_config = {"from_attributes": True} model_config = {"from_attributes": True}
class SensitiveListResponse(BaseModel):
"""敏感信息分页列表响应模型(新增)"""
total: int = Field(..., description="敏感词总记录数")
sensitives: List[SensitiveResponse] = Field(..., description="当前页敏感词列表")
model_config = {"from_attributes": True}

View File

@ -31,15 +31,16 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
ALLOWED_MODEL_EXT = {"pt"} ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
# 全局模型变量(带版本标识) # 全局模型变量(带版本标识和置信度
global _yolo_model, _current_model_version global _yolo_model, _current_model_version, _current_conf_threshold
_yolo_model = None _yolo_model = None
_current_model_version = None # 模型版本标识(用于检测模型是否变化) _current_model_version = None # 模型版本标识
_current_conf_threshold = 0.8 # 默认置信度初始值
router = APIRouter(prefix="/models", tags=["模型管理"]) router = APIRouter(prefix="/models", tags=["模型管理"])
# 服务重启核心工具函数 # 服务重启核心工具函数(保持不变)
def restart_service(): def restart_service():
"""重启当前FastAPI服务进程""" """重启当前FastAPI服务进程"""
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...") print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
@ -87,7 +88,7 @@ def restart_service():
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
# 模型路径验证工具函数 # 模型路径验证工具函数(保持不变)
def get_valid_model_abs_path(relative_path: str) -> str: def get_valid_model_abs_path(relative_path: str) -> str:
try: try:
relative_path = relative_path.replace("/", os.sep) relative_path = relative_path.replace("/", os.sep)
@ -139,7 +140,7 @@ def get_valid_model_abs_path(relative_path: str) -> str:
) from e ) from e
# 对外提供当前模型(带版本校验) # 对外提供当前模型(带版本校验)(保持不变)
def get_current_yolo_model(): def get_current_yolo_model():
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)""" """供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
global _yolo_model, _current_model_version global _yolo_model, _current_model_version
@ -155,21 +156,19 @@ def get_current_yolo_model():
return None return None
# 1. 计算当前默认模型的唯一版本标识 # 1. 计算当前默认模型的唯一版本标识
# (路径哈希 + 文件修改时间戳,确保模型变化时版本变化)
valid_abs_path = get_valid_model_abs_path(default_model["path"]) valid_abs_path = get_valid_model_abs_path(default_model["path"])
model_stat = os.stat(valid_abs_path) model_stat = os.stat(valid_abs_path)
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
# 2. 版本未变化则复用已有模型(核心优化点) # 2. 版本未变化则复用已有模型
if _yolo_model and _current_model_version == model_version: if _yolo_model and _current_model_version == model_version:
# print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...")
return _yolo_model return _yolo_model
# 3. 版本变化时重新加载模型 # 3. 版本变化时重新加载模型
_yolo_model = load_yolo_model(valid_abs_path) _yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model: if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path) setattr(_yolo_model, "model_path", valid_abs_path)
_current_model_version = model_version # 更新版本标识 _current_model_version = model_version
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...") print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...")
else: else:
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}") print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
@ -182,7 +181,14 @@ def get_current_yolo_model():
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 1. 上传模型 # 新增:获取当前置信度阈值
def get_current_conf_threshold():
"""供检测模块获取当前设置的置信度阈值"""
global _current_conf_threshold
return _current_conf_threshold
# 1. 上传模型(保持不变)
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式") @router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
async def upload_model( async def upload_model(
name: str = Form(..., description="模型名称"), name: str = Form(..., description="模型名称"),
@ -255,7 +261,7 @@ async def upload_model(
return APIResponse( return APIResponse(
code=201, code=201,
message=f"模型上传成功ID{new_model['id']}", message=f"模型上传成功ID{new_model['id']}",
data=ModelResponse(**new_model) data=ModelResponse(** new_model)
) )
except MySQLError as e: except MySQLError as e:
@ -273,7 +279,7 @@ async def upload_model(
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 2. 获取模型列表 # 2. 获取模型列表(保持不变)
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)") @router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
async def get_model_list( async def get_model_list(
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
@ -319,7 +325,7 @@ async def get_model_list(
message=f"获取成功!共{total}条记录", message=f"获取成功!共{total}条记录",
data=ModelListResponse( data=ModelListResponse(
total=total, total=total,
models=[ModelResponse(**model) for model in model_list] models=[ModelResponse(** model) for model in model_list]
) )
) )
@ -329,7 +335,7 @@ async def get_model_list(
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 3. 获取默认模型 # 3. 获取默认模型(保持不变)
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型") @router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
async def get_default_model(): async def get_default_model():
conn = None conn = None
@ -371,7 +377,7 @@ async def get_default_model():
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 4. 获取单个模型详情 # 4. 获取单个模型详情(保持不变)
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情") @router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
async def get_model(model_id: int): async def get_model(model_id: int):
conn = None conn = None
@ -392,7 +398,7 @@ async def get_model(model_id: int):
return APIResponse( return APIResponse(
code=200, code=200,
message=f"查询成功,但路径异常:{e.detail}", message=f"查询成功,但路径异常:{e.detail}",
data=ModelResponse(**model) data=ModelResponse(** model)
) )
return APIResponse( return APIResponse(
@ -400,14 +406,13 @@ async def get_model(model_id: int):
message="查询成功", message="查询成功",
data=ModelResponse(**model) data=ModelResponse(**model)
) )
except MySQLError as e: except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 5. 更新模型信息 # 5. 更新模型信息(保持不变)
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息") @router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
async def update_model(model_id: int, model_update: ModelUpdateRequest): async def update_model(model_id: int, model_update: ModelUpdateRequest):
conn = None conn = None
@ -479,7 +484,7 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
return APIResponse( return APIResponse(
code=200, code=200,
message="模型更新成功", message="模型更新成功",
data=ModelResponse(**updated_model) data=ModelResponse(** updated_model)
) )
except MySQLError as e: except MySQLError as e:
@ -490,9 +495,12 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 5.1 更换默认模型(自动重启服务 # 5.1 更换默认模型(添加置信度参数
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)") @router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
async def set_default_model(model_id: int): async def set_default_model(
model_id: int,
conf_threshold: float = Query(0.8, ge=0.01, le=0.99, description="模型检测置信度阈值0.01-0.99")
):
conn = None conn = None
cursor = None cursor = None
try: try:
@ -551,10 +559,11 @@ async def set_default_model(model_id: int):
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone() updated_model = cursor.fetchone()
# 7. 重置版本标识(关键:确保下次检测加载新模型) # 7. 重置版本标识和更新置信度
global _current_model_version global _current_model_version, _current_conf_threshold
_current_model_version = None _current_model_version = None
print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型") _current_conf_threshold = conf_threshold # 保存动态置信度
print(f"[更换默认模型] 已重置模型版本标识,设置新置信度:{conf_threshold}")
# 8. 延迟重启服务 # 8. 延迟重启服务
print(f"[更换默认模型] 成功将在1秒后重启服务以应用新模型ID{model_id}") print(f"[更换默认模型] 成功将在1秒后重启服务以应用新模型ID{model_id}")
@ -566,8 +575,8 @@ async def set_default_model(model_id: int):
# 9. 返回成功响应 # 9. 返回成功响应
return APIResponse( return APIResponse(
code=200, code=200,
message=f"已成功更换默认模型ID{model_id}服务将在1秒后自动重启以应用新模型", message=f"已成功更换默认模型ID{model_id},置信度:{conf_threshold}服务将在1秒后自动重启以应用新模型",
data=ModelResponse(**updated_model) data=ModelResponse(** updated_model)
) )
except MySQLError as e: except MySQLError as e:
@ -580,7 +589,7 @@ async def set_default_model(model_id: int):
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 6. 删除模型 # 6. 删除模型(保持不变)
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
async def delete_model(model_id: int): async def delete_model(model_id: int):
conn = None conn = None
@ -639,7 +648,7 @@ async def delete_model(model_id: int):
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 7. 下载模型文件 # 7. 下载模型文件(保持不变)
@router.get("/{model_id}/download", summary="下载模型文件") @router.get("/{model_id}/download", summary="下载模型文件")
async def download_model(model_id: int): async def download_model(model_id: int):
conn = None conn = None
@ -665,4 +674,4 @@ async def download_model(model_id: int):
except MySQLError as e: except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)

View File

@ -1,8 +1,14 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Query
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
from typing import Optional
from ds.db import db from ds.db import db
from schema.sensitive_schema import SensitiveCreateRequest, SensitiveUpdateRequest, SensitiveResponse from schema.sensitive_schema import (
SensitiveCreateRequest,
SensitiveUpdateRequest,
SensitiveResponse,
SensitiveListResponse # 导入新增的分页响应模型
)
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse from schema.user_schema import UserResponse
@ -19,9 +25,11 @@ router = APIRouter(
# ------------------------------ # ------------------------------
@router.post("", response_model=APIResponse, summary="创建敏感信息记录") @router.post("", response_model=APIResponse, summary="创建敏感信息记录")
async def create_sensitive( async def create_sensitive(
sensitive: SensitiveCreateRequest): # 添加了登录认证依赖 sensitive: SensitiveCreateRequest,
current_user: UserResponse = Depends(get_current_user) # 补充登录认证依赖(与其他接口保持一致)
):
""" """
创建敏感信息记录: 创建敏感信息记录:
- 需登录认证 - 需登录认证
- 插入新的敏感信息记录到数据库ID由数据库自动生成 - 插入新的敏感信息记录到数据库ID由数据库自动生成
- 返回创建成功信息 - 返回创建成功信息
@ -34,8 +42,8 @@ async def create_sensitive(
# 插入新敏感信息记录到数据库不包含ID、由数据库自动生成 # 插入新敏感信息记录到数据库不包含ID、由数据库自动生成
insert_query = """ insert_query = """
INSERT INTO sensitives (name) INSERT INTO sensitives (name, created_at, updated_at)
VALUES (%s) VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""" """
cursor.execute(insert_query, (sensitive.name,)) cursor.execute(insert_query, (sensitive.name,))
conn.commit() conn.commit()
@ -56,12 +64,14 @@ async def create_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"创建敏感信息记录失败: {str(e)}") from e raise HTTPException(
status_code=500,
detail=f"创建敏感信息记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 以下接口代码保持不变
# ------------------------------ # ------------------------------
# 2. 获取单个敏感信息记录 # 2. 获取单个敏感信息记录
# ------------------------------ # ------------------------------
@ -71,7 +81,7 @@ async def get_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
获取单个敏感信息记录: 获取单个敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID查询敏感信息记录 - 根据ID查询敏感信息记录
- 返回查询到的敏感信息 - 返回查询到的敏感信息
@ -98,21 +108,29 @@ async def get_sensitive(
data=SensitiveResponse(**sensitive) data=SensitiveResponse(**sensitive)
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询敏感信息记录失败: {str(e)}") from e raise HTTPException(
status_code=500,
detail=f"查询敏感信息记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 3. 获取所有敏感信息记录 # 3. 获取敏感信息分页列表(重构:支持分页+关键词搜索)
# ------------------------------ # ------------------------------
@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录") @router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
async def get_all_sensitives(): async def get_sensitive_list(
page: int = Query(1, ge=1, description="页码默认1最小1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数默认101-100"),
name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)"),
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
""" """
获取所有敏感信息记录: 获取敏感信息分页列表:
- 需登录认证 - 需登录认证
- 查询所有敏感信息记录(不需要分页 - 支持分页page/page_size和敏感词关键词模糊搜索name
- 返回所有敏感信息列表 - 返回总记录数+当前页数据
""" """
conn = None conn = None
cursor = None cursor = None
@ -120,17 +138,49 @@ async def get_all_sensitives():
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
query = "SELECT * FROM sensitives ORDER BY id" # 1. 构建查询条件(支持关键词搜索)
cursor.execute(query) where_clause = []
sensitives = cursor.fetchall() params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%") # 模糊匹配关键词
# 2. 查询总记录数(用于分页计算)
count_sql = "SELECT COUNT(*) AS total FROM sensitives"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params.copy()) # 复制参数列表,避免后续污染
total = cursor.fetchone()["total"]
# 3. 计算分页偏移量
offset = (page - 1) * page_size
# 4. 分页查询敏感词数据(按更新时间倒序,最新的在前)
list_sql = "SELECT * FROM sensitives"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
# 排序+分页LIMIT 条数 OFFSET 偏移量)
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
# 补充分页参数page_size和offset
params.extend([page_size, offset])
cursor.execute(list_sql, params)
sensitive_list = cursor.fetchall()
# 5. 构造分页响应数据
return APIResponse( return APIResponse(
code=200, code=200,
message="所有敏感信息记录查询成功", message=f"敏感信息列表查询成功(共{total}条记录,当前第{page}页)",
data=[SensitiveResponse(**sensitive) for sensitive in sensitives] data=SensitiveListResponse(
total=total,
sensitives=[SensitiveResponse(**item) for item in sensitive_list]
)
) )
except MySQLError as e: except MySQLError as e:
raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e raise HTTPException(
status_code=500,
detail=f"查询敏感信息列表失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -145,7 +195,7 @@ async def update_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
更新敏感信息记录: 更新敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID更新敏感信息记录 - 根据ID更新敏感信息记录
- 返回更新后的敏感信息 - 返回更新后的敏感信息
@ -177,14 +227,16 @@ async def update_sensitive(
if not update_fields: if not update_fields:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="至少需要提供一个字段进行更新" detail="至少需要提供一个字段进行更新name"
) )
params.append(sensitive_id) # WHERE条件参数 # 补充更新时间和WHERE条件参数
update_fields.append("updated_at = CURRENT_TIMESTAMP")
params.append(sensitive_id)
update_query = f""" update_query = f"""
UPDATE sensitives UPDATE sensitives
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP SET {', '.join(update_fields)}
WHERE id = %s WHERE id = %s
""" """
cursor.execute(update_query, params) cursor.execute(update_query, params)
@ -203,7 +255,10 @@ async def update_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"更新敏感信息记录失败: {str(e)}") from e raise HTTPException(
status_code=500,
detail=f"更新敏感信息记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
@ -217,7 +272,7 @@ async def delete_sensitive(
current_user: UserResponse = Depends(get_current_user) # 需登录认证 current_user: UserResponse = Depends(get_current_user) # 需登录认证
): ):
""" """
删除敏感信息记录: 删除敏感信息记录:
- 需登录认证 - 需登录认证
- 根据ID删除敏感信息记录 - 根据ID删除敏感信息记录
- 返回删除成功信息 - 返回删除成功信息
@ -251,14 +306,20 @@ async def delete_sensitive(
except MySQLError as e: except MySQLError as e:
if conn: if conn:
conn.rollback() conn.rollback()
raise Exception(f"删除敏感信息记录失败: {str(e)}") from e raise HTTPException(
status_code=500,
detail=f"删除敏感信息记录失败: {str(e)}"
) from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------
# 6. 业务辅助函数:获取所有敏感词(供其他模块调用)
# ------------------------------
def get_all_sensitive_words() -> list[str]: def get_all_sensitive_words() -> list[str]:
""" """
获取所有敏感词返回字符串数组 获取所有敏感词返回字符串列表,用于过滤业务)
返回: 返回:
list[str]: 包含所有敏感词的数组 list[str]: 包含所有敏感词的数组
@ -273,17 +334,17 @@ def get_all_sensitive_words() -> list[str]:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 执行查询只获取敏感词字段 # 执行查询只获取敏感词字段按ID排序
query = "SELECT name FROM sensitives ORDER BY id" query = "SELECT name FROM sensitives ORDER BY id"
cursor.execute(query) cursor.execute(query)
sensitive_records = cursor.fetchall() sensitive_records = cursor.fetchall()
# 提取敏感词到数组 # 提取敏感词到纯字符串数组
return [record['name'] for record in sensitive_records] return [record['name'] for record in sensitive_records]
except MySQLError as e: except MySQLError as e:
# 数据库错误处理 # 数据库错误向上抛出,由调用方处理
raise MySQLError(f"查询敏感词失败: {str(e)}") from e raise MySQLError(f"查询敏感词列表失败: {str(e)}") from e
finally: finally:
# 确保资源正确释放 # 确保数据库连接正确释放
db.close_connection(conn, cursor) db.close_connection(conn, cursor)