diff --git a/core/yolo.py b/core/yolo.py index 1f58648..ec14f6e 100644 --- a/core/yolo.py +++ b/core/yolo.py @@ -1,5 +1,5 @@ from ultralytics import YOLO -from service.model_service import get_current_yolo_model # 带版本校验的模型获取 +from service.model_service import get_current_yolo_model, get_current_conf_threshold # 新增置信度获取函数 def load_model(model_path=None): @@ -15,8 +15,8 @@ def load_model(model_path=None): return None -def detect(frame, conf_threshold=0.7): - """执行目标检测(仅模型版本变化时重新加载,平时复用缓存)""" +def detect(frame): + """执行目标检测(使用动态置信度,仅模型版本变化时重新加载)""" # 获取模型(内部已做版本校验,未变化则直接返回缓存) current_model = load_model() if not current_model: @@ -26,6 +26,8 @@ def detect(frame, conf_threshold=0.7): return (False, "无效输入帧") try: + # 获取动态置信度(从全局配置中读取) + conf_threshold = get_current_conf_threshold() # 用当前模型执行检测(复用缓存,无额外加载耗时) results = current_model(frame, conf=conf_threshold, verbose=False) has_results = len(results[0].boxes) > 0 if results else False @@ -43,13 +45,8 @@ def detect(frame, conf_threshold=0.7): class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}" result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})") - # 打印当前使用的模型路径和版本(用于验证) - # model_path = getattr(current_model, "model_path", "未知路径") - # from service.model_service import _current_model_version - # print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...)") - return (True, "; ".join(result_parts)) except Exception as e: print(f"YOLO检测过程出错:{str(e)}") - return (False, f"检测错误:{str(e)}") + return (False, f"检测错误:{str(e)}") \ No newline at end of file diff --git a/schema/sensitive_schema.py b/schema/sensitive_schema.py index f0100a7..656159e 100644 --- a/schema/sensitive_schema.py +++ b/schema/sensitive_schema.py @@ -1,5 +1,6 @@ from datetime import datetime from pydantic import BaseModel, Field +from typing import List, Optional # ------------------------------ @@ -7,24 +8,31 @@ from pydantic import BaseModel, Field # ------------------------------ class SensitiveCreateRequest(BaseModel): """创建敏感信息记录请求模型""" - # 移除了id字段、由数据库自动生成 - name: str = Field(None, max_length=255, description="名称") + name: str = Field(..., max_length=255, description="敏感词内容(必填)") class SensitiveUpdateRequest(BaseModel): """更新敏感信息记录请求模型""" - name: str = Field(None, max_length=255, description="名称") + name: Optional[str] = Field(None, max_length=255, description="敏感词内容(可选修改)") # ------------------------------ # 响应模型(后端返回数据) # ------------------------------ class SensitiveResponse(BaseModel): - """敏感信息记录响应模型""" - id: int = Field(..., description="主键ID") # 响应中仍然包含ID - name: str = Field(None, description="名称") + """敏感信息单条记录响应模型""" + id: int = Field(..., description="主键ID") + name: str = Field(..., description="敏感词内容") created_at: datetime = Field(..., description="记录创建时间") updated_at: datetime = Field(..., description="记录更新时间") - # 支持从数据库查询结果转换 + # 支持从数据库查询结果(字典/对象)自动转换 model_config = {"from_attributes": True} + + +class SensitiveListResponse(BaseModel): + """敏感信息分页列表响应模型(新增)""" + total: int = Field(..., description="敏感词总记录数") + sensitives: List[SensitiveResponse] = Field(..., description="当前页敏感词列表") + + model_config = {"from_attributes": True} \ No newline at end of file diff --git a/service/model_service.py b/service/model_service.py index 599920e..7603710 100644 --- a/service/model_service.py +++ b/service/model_service.py @@ -31,15 +31,16 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep ALLOWED_MODEL_EXT = {"pt"} MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB -# 全局模型变量(带版本标识) -global _yolo_model, _current_model_version +# 全局模型变量(带版本标识和置信度) +global _yolo_model, _current_model_version, _current_conf_threshold _yolo_model = None -_current_model_version = None # 模型版本标识(用于检测模型是否变化) +_current_model_version = None # 模型版本标识 +_current_conf_threshold = 0.8 # 默认置信度初始值 router = APIRouter(prefix="/models", tags=["模型管理"]) -# 服务重启核心工具函数 +# 服务重启核心工具函数(保持不变) def restart_service(): """重启当前FastAPI服务进程""" print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...") @@ -87,7 +88,7 @@ def restart_service(): raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e -# 模型路径验证工具函数 +# 模型路径验证工具函数(保持不变) def get_valid_model_abs_path(relative_path: str) -> str: try: relative_path = relative_path.replace("/", os.sep) @@ -139,7 +140,7 @@ def get_valid_model_abs_path(relative_path: str) -> str: ) from e -# 对外提供当前模型(带版本校验) +# 对外提供当前模型(带版本校验)(保持不变) def get_current_yolo_model(): """供检测模块获取当前最新默认模型(仅版本变化时重新加载)""" global _yolo_model, _current_model_version @@ -155,21 +156,19 @@ def get_current_yolo_model(): return None # 1. 计算当前默认模型的唯一版本标识 - # (路径哈希 + 文件修改时间戳,确保模型变化时版本变化) valid_abs_path = get_valid_model_abs_path(default_model["path"]) model_stat = os.stat(valid_abs_path) model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" - # 2. 版本未变化则复用已有模型(核心优化点) + # 2. 版本未变化则复用已有模型 if _yolo_model and _current_model_version == model_version: - # print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...)") return _yolo_model # 3. 版本变化时重新加载模型 _yolo_model = load_yolo_model(valid_abs_path) if _yolo_model: setattr(_yolo_model, "model_path", valid_abs_path) - _current_model_version = model_version # 更新版本标识 + _current_model_version = model_version print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)") else: print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}") @@ -182,7 +181,14 @@ def get_current_yolo_model(): db.close_connection(conn, cursor) -# 1. 上传模型 +# 新增:获取当前置信度阈值 +def get_current_conf_threshold(): + """供检测模块获取当前设置的置信度阈值""" + global _current_conf_threshold + return _current_conf_threshold + + +# 1. 上传模型(保持不变) @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") async def upload_model( name: str = Form(..., description="模型名称"), @@ -255,7 +261,7 @@ async def upload_model( return APIResponse( code=201, message=f"模型上传成功!ID:{new_model['id']}", - data=ModelResponse(**new_model) + data=ModelResponse(** new_model) ) except MySQLError as e: @@ -273,7 +279,7 @@ async def upload_model( db.close_connection(conn, cursor) -# 2. 获取模型列表 +# 2. 获取模型列表(保持不变) @router.get("", response_model=APIResponse, summary="获取模型列表(分页)") async def get_model_list( page: int = Query(1, ge=1), @@ -319,7 +325,7 @@ async def get_model_list( message=f"获取成功!共{total}条记录", data=ModelListResponse( total=total, - models=[ModelResponse(**model) for model in model_list] + models=[ModelResponse(** model) for model in model_list] ) ) @@ -329,7 +335,7 @@ async def get_model_list( db.close_connection(conn, cursor) -# 3. 获取默认模型 +# 3. 获取默认模型(保持不变) @router.get("/default", response_model=APIResponse, summary="获取当前默认模型") async def get_default_model(): conn = None @@ -371,7 +377,7 @@ async def get_default_model(): db.close_connection(conn, cursor) -# 4. 获取单个模型详情 +# 4. 获取单个模型详情(保持不变) @router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情") async def get_model(model_id: int): conn = None @@ -392,7 +398,7 @@ async def get_model(model_id: int): return APIResponse( code=200, message=f"查询成功,但路径异常:{e.detail}", - data=ModelResponse(**model) + data=ModelResponse(** model) ) return APIResponse( @@ -400,14 +406,13 @@ async def get_model(model_id: int): message="查询成功", data=ModelResponse(**model) ) - except MySQLError as e: raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e finally: db.close_connection(conn, cursor) -# 5. 更新模型信息 +# 5. 更新模型信息(保持不变) @router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息") async def update_model(model_id: int, model_update: ModelUpdateRequest): conn = None @@ -479,7 +484,7 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest): return APIResponse( code=200, message="模型更新成功", - data=ModelResponse(**updated_model) + data=ModelResponse(** updated_model) ) except MySQLError as e: @@ -490,9 +495,12 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest): db.close_connection(conn, cursor) -# 5.1 更换默认模型(自动重启服务) +# 5.1 更换默认模型(添加置信度参数) @router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)") -async def set_default_model(model_id: int): +async def set_default_model( + model_id: int, + conf_threshold: float = Query(0.8, ge=0.01, le=0.99, description="模型检测置信度阈值(0.01-0.99)") +): conn = None cursor = None try: @@ -551,10 +559,11 @@ async def set_default_model(model_id: int): cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) updated_model = cursor.fetchone() - # 7. 重置版本标识(关键:确保下次检测加载新模型) - global _current_model_version + # 7. 重置版本标识和更新置信度 + global _current_model_version, _current_conf_threshold _current_model_version = None - print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型") + _current_conf_threshold = conf_threshold # 保存动态置信度 + print(f"[更换默认模型] 已重置模型版本标识,设置新置信度:{conf_threshold}") # 8. 延迟重启服务 print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})") @@ -566,8 +575,8 @@ async def set_default_model(model_id: int): # 9. 返回成功响应 return APIResponse( code=200, - message=f"已成功更换默认模型(ID:{model_id})!服务将在1秒后自动重启以应用新模型", - data=ModelResponse(**updated_model) + message=f"已成功更换默认模型(ID:{model_id}),置信度:{conf_threshold}!服务将在1秒后自动重启以应用新模型", + data=ModelResponse(** updated_model) ) except MySQLError as e: @@ -580,7 +589,7 @@ async def set_default_model(model_id: int): db.close_connection(conn, cursor) -# 6. 删除模型 +# 6. 删除模型(保持不变) @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") async def delete_model(model_id: int): conn = None @@ -639,7 +648,7 @@ async def delete_model(model_id: int): db.close_connection(conn, cursor) -# 7. 下载模型文件 +# 7. 下载模型文件(保持不变) @router.get("/{model_id}/download", summary="下载模型文件") async def download_model(model_id: int): conn = None @@ -665,4 +674,4 @@ async def download_model(model_id: int): except MySQLError as e: raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e finally: - db.close_connection(conn, cursor) + db.close_connection(conn, cursor) \ No newline at end of file diff --git a/service/sensitive_service.py b/service/sensitive_service.py index 95232cc..de958af 100644 --- a/service/sensitive_service.py +++ b/service/sensitive_service.py @@ -1,8 +1,14 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query from mysql.connector import Error as MySQLError +from typing import Optional from ds.db import db -from schema.sensitive_schema import SensitiveCreateRequest, SensitiveUpdateRequest, SensitiveResponse +from schema.sensitive_schema import ( + SensitiveCreateRequest, + SensitiveUpdateRequest, + SensitiveResponse, + SensitiveListResponse # 导入新增的分页响应模型 +) from schema.response_schema import APIResponse from middle.auth_middleware import get_current_user from schema.user_schema import UserResponse @@ -19,9 +25,11 @@ router = APIRouter( # ------------------------------ @router.post("", response_model=APIResponse, summary="创建敏感信息记录") async def create_sensitive( - sensitive: SensitiveCreateRequest): # 添加了登录认证依赖 + sensitive: SensitiveCreateRequest, + current_user: UserResponse = Depends(get_current_user) # 补充登录认证依赖(与其他接口保持一致) +): """ - 创建敏感信息记录: + 创建敏感信息记录: - 需登录认证 - 插入新的敏感信息记录到数据库(ID由数据库自动生成) - 返回创建成功信息 @@ -34,8 +42,8 @@ async def create_sensitive( # 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成) insert_query = """ - INSERT INTO sensitives (name) - VALUES (%s) + INSERT INTO sensitives (name, created_at, updated_at) + VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) """ cursor.execute(insert_query, (sensitive.name,)) conn.commit() @@ -56,12 +64,14 @@ async def create_sensitive( except MySQLError as e: if conn: conn.rollback() - raise Exception(f"创建敏感信息记录失败: {str(e)}") from e + raise HTTPException( + status_code=500, + detail=f"创建敏感信息记录失败: {str(e)}" + ) from e finally: db.close_connection(conn, cursor) -# 以下接口代码保持不变 # ------------------------------ # 2. 获取单个敏感信息记录 # ------------------------------ @@ -71,7 +81,7 @@ async def get_sensitive( current_user: UserResponse = Depends(get_current_user) # 需登录认证 ): """ - 获取单个敏感信息记录: + 获取单个敏感信息记录: - 需登录认证 - 根据ID查询敏感信息记录 - 返回查询到的敏感信息 @@ -98,21 +108,29 @@ async def get_sensitive( data=SensitiveResponse(**sensitive) ) except MySQLError as e: - raise Exception(f"查询敏感信息记录失败: {str(e)}") from e + raise HTTPException( + status_code=500, + detail=f"查询敏感信息记录失败: {str(e)}" + ) from e finally: db.close_connection(conn, cursor) # ------------------------------ -# 3. 获取所有敏感信息记录 +# 3. 获取敏感信息分页列表(重构:支持分页+关键词搜索) # ------------------------------ -@router.get("", response_model=APIResponse, summary="获取所有敏感信息记录") -async def get_all_sensitives(): +@router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)") +async def get_sensitive_list( + page: int = Query(1, ge=1, description="页码(默认1,最小1)"), + page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10,1-100)"), + name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)"), + current_user: UserResponse = Depends(get_current_user) # 需登录认证 +): """ - 获取所有敏感信息记录: + 获取敏感信息分页列表: - 需登录认证 - - 查询所有敏感信息记录(不需要分页) - - 返回所有敏感信息列表 + - 支持分页(page/page_size)和敏感词关键词模糊搜索(name) + - 返回总记录数+当前页数据 """ conn = None cursor = None @@ -120,17 +138,49 @@ async def get_all_sensitives(): conn = db.get_connection() cursor = conn.cursor(dictionary=True) - query = "SELECT * FROM sensitives ORDER BY id" - cursor.execute(query) - sensitives = cursor.fetchall() + # 1. 构建查询条件(支持关键词搜索) + where_clause = [] + params = [] + if name: + where_clause.append("name LIKE %s") + params.append(f"%{name}%") # 模糊匹配关键词 + # 2. 查询总记录数(用于分页计算) + count_sql = "SELECT COUNT(*) AS total FROM sensitives" + if where_clause: + count_sql += " WHERE " + " AND ".join(where_clause) + cursor.execute(count_sql, params.copy()) # 复制参数列表,避免后续污染 + total = cursor.fetchone()["total"] + + # 3. 计算分页偏移量 + offset = (page - 1) * page_size + + # 4. 分页查询敏感词数据(按更新时间倒序,最新的在前) + list_sql = "SELECT * FROM sensitives" + if where_clause: + list_sql += " WHERE " + " AND ".join(where_clause) + # 排序+分页(LIMIT 条数 OFFSET 偏移量) + list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s" + # 补充分页参数(page_size和offset) + params.extend([page_size, offset]) + + cursor.execute(list_sql, params) + sensitive_list = cursor.fetchall() + + # 5. 构造分页响应数据 return APIResponse( code=200, - message="所有敏感信息记录查询成功", - data=[SensitiveResponse(**sensitive) for sensitive in sensitives] + message=f"敏感信息列表查询成功(共{total}条记录,当前第{page}页)", + data=SensitiveListResponse( + total=total, + sensitives=[SensitiveResponse(**item) for item in sensitive_list] + ) ) except MySQLError as e: - raise Exception(f"查询所有敏感信息记录失败: {str(e)}") from e + raise HTTPException( + status_code=500, + detail=f"查询敏感信息列表失败: {str(e)}" + ) from e finally: db.close_connection(conn, cursor) @@ -145,7 +195,7 @@ async def update_sensitive( current_user: UserResponse = Depends(get_current_user) # 需登录认证 ): """ - 更新敏感信息记录: + 更新敏感信息记录: - 需登录认证 - 根据ID更新敏感信息记录 - 返回更新后的敏感信息 @@ -177,14 +227,16 @@ async def update_sensitive( if not update_fields: raise HTTPException( status_code=400, - detail="至少需要提供一个字段进行更新" + detail="至少需要提供一个字段进行更新(如:name)" ) - params.append(sensitive_id) # WHERE条件的参数 + # 补充更新时间和WHERE条件参数 + update_fields.append("updated_at = CURRENT_TIMESTAMP") + params.append(sensitive_id) update_query = f""" UPDATE sensitives - SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP + SET {', '.join(update_fields)} WHERE id = %s """ cursor.execute(update_query, params) @@ -203,7 +255,10 @@ async def update_sensitive( except MySQLError as e: if conn: conn.rollback() - raise Exception(f"更新敏感信息记录失败: {str(e)}") from e + raise HTTPException( + status_code=500, + detail=f"更新敏感信息记录失败: {str(e)}" + ) from e finally: db.close_connection(conn, cursor) @@ -217,7 +272,7 @@ async def delete_sensitive( current_user: UserResponse = Depends(get_current_user) # 需登录认证 ): """ - 删除敏感信息记录: + 删除敏感信息记录: - 需登录认证 - 根据ID删除敏感信息记录 - 返回删除成功信息 @@ -251,14 +306,20 @@ async def delete_sensitive( except MySQLError as e: if conn: conn.rollback() - raise Exception(f"删除敏感信息记录失败: {str(e)}") from e + raise HTTPException( + status_code=500, + detail=f"删除敏感信息记录失败: {str(e)}" + ) from e finally: db.close_connection(conn, cursor) +# ------------------------------ +# 6. 业务辅助函数:获取所有敏感词(供其他模块调用) +# ------------------------------ def get_all_sensitive_words() -> list[str]: """ - 获取所有敏感词、返回字符串数组 + 获取所有敏感词(返回纯字符串列表,用于过滤业务) 返回: list[str]: 包含所有敏感词的数组 @@ -273,17 +334,17 @@ def get_all_sensitive_words() -> list[str]: conn = db.get_connection() cursor = conn.cursor(dictionary=True) - # 执行查询、只获取敏感词字段 + # 执行查询(只获取敏感词字段,按ID排序) query = "SELECT name FROM sensitives ORDER BY id" cursor.execute(query) sensitive_records = cursor.fetchall() - # 提取敏感词到数组中 + # 提取敏感词到纯字符串数组 return [record['name'] for record in sensitive_records] except MySQLError as e: - # 数据库错误处理 - raise MySQLError(f"查询敏感词失败: {str(e)}") from e + # 数据库错误向上抛出,由调用方处理 + raise MySQLError(f"查询敏感词列表失败: {str(e)}") from e finally: - # 确保资源正确释放 + # 确保数据库连接正确释放 db.close_connection(conn, cursor) \ No newline at end of file