| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | import subprocess | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | import os | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | import sys | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | import shutil | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | import threading | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | from pathlib import Path | 
					
						
							|  |  |  |  | from datetime import datetime | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query | 
					
						
							|  |  |  |  | from fastapi.responses import FileResponse | 
					
						
							|  |  |  |  | from mysql.connector import Error as MySQLError | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | # 复用项目依赖 | 
					
						
							|  |  |  |  | from ds.db import db | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | from encryption.encrypt_decorator import encrypt_response | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | from schema.model_schema import ( | 
					
						
							|  |  |  |  |     ModelCreateRequest, | 
					
						
							|  |  |  |  |     ModelUpdateRequest, | 
					
						
							|  |  |  |  |     ModelResponse, | 
					
						
							|  |  |  |  |     ModelListResponse | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | from schema.response_schema import APIResponse | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | from util.model_util import load_yolo_model  # 模型加载工具 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | # 路径配置 | 
					
						
							|  |  |  |  | CURRENT_FILE_PATH = Path(__file__).resolve() | 
					
						
							|  |  |  |  | PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent | 
					
						
							|  |  |  |  | MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models" | 
					
						
							|  |  |  |  | MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True) | 
					
						
							|  |  |  |  | DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # 模型限制 | 
					
						
							|  |  |  |  | ALLOWED_MODEL_EXT = {"pt"} | 
					
						
							|  |  |  |  | MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 全局模型变量(带版本标识和置信度) | 
					
						
							|  |  |  |  | global _yolo_model, _current_model_version, _current_conf_threshold | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | _yolo_model = None | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | _current_model_version = None  # 模型版本标识 | 
					
						
							|  |  |  |  | _current_conf_threshold = 0.8  # 默认置信度初始值 | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | router = APIRouter(prefix="/models", tags=["模型管理"]) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 服务重启核心工具函数(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | def restart_service(): | 
					
						
							|  |  |  |  |     """重启当前FastAPI服务进程""" | 
					
						
							|  |  |  |  |     print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...") | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         # 关闭所有WebSocket连接 | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             from ws import connected_clients | 
					
						
							|  |  |  |  |             if connected_clients: | 
					
						
							|  |  |  |  |                 print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接") | 
					
						
							|  |  |  |  |                 for ip, conn in list(connected_clients.items()): | 
					
						
							|  |  |  |  |                     try: | 
					
						
							|  |  |  |  |                         if conn.consumer_task and not conn.consumer_task.done(): | 
					
						
							|  |  |  |  |                             conn.consumer_task.cancel() | 
					
						
							|  |  |  |  |                         conn.websocket.close(code=1001, reason="模型更新,服务重启") | 
					
						
							|  |  |  |  |                         connected_clients.pop(ip) | 
					
						
							|  |  |  |  |                     except Exception as e: | 
					
						
							|  |  |  |  |                         print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}") | 
					
						
							|  |  |  |  |         except ImportError: | 
					
						
							|  |  |  |  |             print("[服务重启] 未找到WebSocket连接管理模块,跳过连接关闭") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 关闭数据库连接 | 
					
						
							|  |  |  |  |         if hasattr(db, "close_all_connections"): | 
					
						
							|  |  |  |  |             db.close_all_connections() | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             print("[警告] db模块未实现close_all_connections,可能存在连接泄漏") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 启动新进程 | 
					
						
							|  |  |  |  |         python_exec = sys.executable | 
					
						
							|  |  |  |  |         current_argv = sys.argv | 
					
						
							|  |  |  |  |         print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}") | 
					
						
							|  |  |  |  |         subprocess.Popen( | 
					
						
							|  |  |  |  |             [python_exec] + current_argv, | 
					
						
							|  |  |  |  |             close_fds=True, | 
					
						
							|  |  |  |  |             start_new_session=True, | 
					
						
							|  |  |  |  |             stdout=subprocess.PIPE, | 
					
						
							|  |  |  |  |             stderr=subprocess.PIPE | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 退出当前进程 | 
					
						
							|  |  |  |  |         print("[服务重启] 新进程已启动,当前进程退出") | 
					
						
							|  |  |  |  |         sys.exit(0) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         print(f"[服务重启] 重启失败:{str(e)}") | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 模型路径验证工具函数(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | def get_valid_model_abs_path(relative_path: str) -> str: | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         relative_path = relative_path.replace("/", os.sep) | 
					
						
							|  |  |  |  |         model_abs_path = PROJECT_ROOT / relative_path | 
					
						
							|  |  |  |  |         model_abs_path = model_abs_path.resolve() | 
					
						
							|  |  |  |  |         model_abs_path_str = str(model_abs_path) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)): | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=400, | 
					
						
							|  |  |  |  |                 detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not model_abs_path.exists(): | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=404, | 
					
						
							|  |  |  |  |                 detail=f"模型文件不存在!路径:{model_abs_path_str}" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not model_abs_path.is_file(): | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=400, | 
					
						
							|  |  |  |  |                 detail=f"路径不是文件!路径:{model_abs_path_str}" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         file_size = model_abs_path.stat().st_size | 
					
						
							|  |  |  |  |         if file_size > MAX_MODEL_SIZE: | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=400, | 
					
						
							|  |  |  |  |                 detail=f"模型文件过大({file_size // 1024 // 1024}MB),超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         file_ext = model_abs_path.suffix.lower() | 
					
						
							|  |  |  |  |         if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]: | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=400, | 
					
						
							|  |  |  |  |                 detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB") | 
					
						
							|  |  |  |  |         return model_abs_path_str | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except HTTPException as e: | 
					
						
							|  |  |  |  |         raise e | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         raise HTTPException( | 
					
						
							|  |  |  |  |             status_code=500, | 
					
						
							|  |  |  |  |             detail=f"路径处理失败:{str(e)}" | 
					
						
							|  |  |  |  |         ) from e | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 对外提供当前模型(带版本校验)(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | def get_current_yolo_model(): | 
					
						
							|  |  |  |  |     """供检测模块获取当前最新默认模型(仅版本变化时重新加载)""" | 
					
						
							|  |  |  |  |     global _yolo_model, _current_model_version | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  |         cursor.execute("SELECT path FROM model WHERE is_default = 1") | 
					
						
							|  |  |  |  |         default_model = cursor.fetchone() | 
					
						
							|  |  |  |  |         if not default_model: | 
					
						
							|  |  |  |  |             print("[get_current_yolo_model] 暂无默认模型") | 
					
						
							|  |  |  |  |             return None | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 1. 计算当前默认模型的唯一版本标识 | 
					
						
							|  |  |  |  |         valid_abs_path = get_valid_model_abs_path(default_model["path"]) | 
					
						
							|  |  |  |  |         model_stat = os.stat(valid_abs_path) | 
					
						
							|  |  |  |  |         model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |         # 2. 版本未变化则复用已有模型 | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |         if _yolo_model and _current_model_version == model_version: | 
					
						
							|  |  |  |  |             return _yolo_model | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 3. 版本变化时重新加载模型 | 
					
						
							|  |  |  |  |         _yolo_model = load_yolo_model(valid_abs_path) | 
					
						
							|  |  |  |  |         if _yolo_model: | 
					
						
							|  |  |  |  |             setattr(_yolo_model, "model_path", valid_abs_path) | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |             _current_model_version = model_version | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |             print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)") | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}") | 
					
						
							|  |  |  |  |         return _yolo_model | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         print(f"[get_current_yolo_model] 加载失败:{str(e)}") | 
					
						
							|  |  |  |  |         return None | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 新增:获取当前置信度阈值 | 
					
						
							|  |  |  |  | def get_current_conf_threshold(): | 
					
						
							|  |  |  |  |     """供检测模块获取当前设置的置信度阈值""" | 
					
						
							|  |  |  |  |     global _current_conf_threshold | 
					
						
							|  |  |  |  |     return _current_conf_threshold | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # 1. 上传模型(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | @encrypt_response() | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | async def upload_model( | 
					
						
							|  |  |  |  |         name: str = Form(..., description="模型名称"), | 
					
						
							|  |  |  |  |         description: str = Form(None, description="模型描述"), | 
					
						
							|  |  |  |  |         is_default: bool = Form(False, description="是否设为默认模型"), | 
					
						
							|  |  |  |  |         file: UploadFile = File(..., description=f"YOLO模型文件(.pt,最大{MAX_MODEL_SIZE // 1024 // 1024}MB)") | 
					
						
							|  |  |  |  | ): | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     saved_file_path = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         # 校验文件 | 
					
						
							|  |  |  |  |         file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "" | 
					
						
							|  |  |  |  |         if file_ext not in ALLOWED_MODEL_EXT: | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=400, | 
					
						
							|  |  |  |  |                 detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  |         if file.size > MAX_MODEL_SIZE: | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=400, | 
					
						
							|  |  |  |  |                 detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB,当前{file.size // 1024 // 1024}MB" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 保存文件 | 
					
						
							|  |  |  |  |         timestamp = datetime.now().strftime("%Y%m%d%H%M%S") | 
					
						
							|  |  |  |  |         safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}" | 
					
						
							|  |  |  |  |         saved_file_path = MODEL_SAVE_ROOT / safe_filename | 
					
						
							|  |  |  |  |         with open(saved_file_path, "wb") as f: | 
					
						
							|  |  |  |  |             shutil.copyfileobj(file.file, f) | 
					
						
							|  |  |  |  |         saved_file_path.chmod(0o644)  # 设置权限 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 数据库路径处理 | 
					
						
							|  |  |  |  |         db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 数据库操作 | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if is_default: | 
					
						
							|  |  |  |  |             cursor.execute("UPDATE model SET is_default = 0") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         insert_sql = """
 | 
					
						
							|  |  |  |  |             INSERT INTO model (name, path, is_default, description, file_size) | 
					
						
							|  |  |  |  |             VALUES (%s, %s, %s, %s, %s) | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size)) | 
					
						
							|  |  |  |  |         conn.commit() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()") | 
					
						
							|  |  |  |  |         new_model = cursor.fetchone() | 
					
						
							|  |  |  |  |         if not new_model: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=500, detail="上传成功但无法获取记录") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |         # 加载默认模型并更新版本 | 
					
						
							|  |  |  |  |         global _yolo_model, _current_model_version | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |         if is_default: | 
					
						
							|  |  |  |  |             valid_abs_path = get_valid_model_abs_path(db_relative_path) | 
					
						
							|  |  |  |  |             _yolo_model = load_yolo_model(valid_abs_path) | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |             if _yolo_model: | 
					
						
							|  |  |  |  |                 setattr(_yolo_model, "model_path", valid_abs_path) | 
					
						
							|  |  |  |  |                 model_stat = os.stat(valid_abs_path) | 
					
						
							|  |  |  |  |                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | 
					
						
							|  |  |  |  |             else: | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |                 raise HTTPException( | 
					
						
							|  |  |  |  |                     status_code=500, | 
					
						
							|  |  |  |  |                     detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})" | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=201, | 
					
						
							|  |  |  |  |             message=f"模型上传成功!ID:{new_model['id']}", | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |             data=ModelResponse(** new_model) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         if conn: | 
					
						
							|  |  |  |  |             conn.rollback() | 
					
						
							|  |  |  |  |         if saved_file_path and saved_file_path.exists(): | 
					
						
							|  |  |  |  |             saved_file_path.unlink() | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         if saved_file_path and saved_file_path.exists(): | 
					
						
							|  |  |  |  |             saved_file_path.unlink() | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         await file.close() | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 2. 获取模型列表(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | @router.get("", response_model=APIResponse, summary="获取模型列表(分页)") | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | @encrypt_response() | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | async def get_model_list( | 
					
						
							|  |  |  |  |         page: int = Query(1, ge=1), | 
					
						
							|  |  |  |  |         page_size: int = Query(10, ge=1, le=100), | 
					
						
							|  |  |  |  |         name: str = Query(None), | 
					
						
							|  |  |  |  |         is_default: bool = Query(None) | 
					
						
							|  |  |  |  | ): | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         where_clause = [] | 
					
						
							|  |  |  |  |         params = [] | 
					
						
							|  |  |  |  |         if name: | 
					
						
							|  |  |  |  |             where_clause.append("name LIKE %s") | 
					
						
							|  |  |  |  |             params.append(f"%{name}%") | 
					
						
							|  |  |  |  |         if is_default is not None: | 
					
						
							|  |  |  |  |             where_clause.append("is_default = %s") | 
					
						
							|  |  |  |  |             params.append(1 if is_default else 0) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 总记录数 | 
					
						
							|  |  |  |  |         count_sql = "SELECT COUNT(*) AS total FROM model" | 
					
						
							|  |  |  |  |         if where_clause: | 
					
						
							|  |  |  |  |             count_sql += " WHERE " + " AND ".join(where_clause) | 
					
						
							|  |  |  |  |         cursor.execute(count_sql, params) | 
					
						
							|  |  |  |  |         total = cursor.fetchone()["total"] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 分页数据 | 
					
						
							|  |  |  |  |         offset = (page - 1) * page_size | 
					
						
							|  |  |  |  |         list_sql = "SELECT * FROM model" | 
					
						
							|  |  |  |  |         if where_clause: | 
					
						
							|  |  |  |  |             list_sql += " WHERE " + " AND ".join(where_clause) | 
					
						
							|  |  |  |  |         list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s" | 
					
						
							|  |  |  |  |         params.extend([page_size, offset]) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute(list_sql, params) | 
					
						
							|  |  |  |  |         model_list = cursor.fetchall() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=200, | 
					
						
							|  |  |  |  |             message=f"获取成功!共{total}条记录", | 
					
						
							|  |  |  |  |             data=ModelListResponse( | 
					
						
							|  |  |  |  |                 total=total, | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |                 models=[ModelResponse(** model) for model in model_list] | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |             ) | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 3. 获取默认模型(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | @router.get("/default", response_model=APIResponse, summary="获取当前默认模型") | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | @encrypt_response() | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | async def get_default_model(): | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE is_default = 1") | 
					
						
							|  |  |  |  |         default_model = cursor.fetchone() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not default_model: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=404, detail="暂无默认模型") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         valid_abs_path = get_valid_model_abs_path(default_model["path"]) | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |         global _yolo_model, _current_model_version | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not _yolo_model: | 
					
						
							|  |  |  |  |             _yolo_model = load_yolo_model(valid_abs_path) | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |             if _yolo_model: | 
					
						
							|  |  |  |  |                 setattr(_yolo_model, "model_path", valid_abs_path) | 
					
						
							|  |  |  |  |                 model_stat = os.stat(valid_abs_path) | 
					
						
							|  |  |  |  |                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | 
					
						
							|  |  |  |  |             else: | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |                 raise HTTPException( | 
					
						
							|  |  |  |  |                     status_code=500, | 
					
						
							|  |  |  |  |                     detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})" | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=200, | 
					
						
							|  |  |  |  |             message="默认模型查询成功", | 
					
						
							|  |  |  |  |             data=ModelResponse(**default_model) | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 4. 获取单个模型详情(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | @router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情") | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | @encrypt_response() | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | async def get_model(model_id: int): | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |         model = cursor.fetchone() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not model: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             model_abs_path = get_valid_model_abs_path(model["path"]) | 
					
						
							|  |  |  |  |         except HTTPException as e: | 
					
						
							|  |  |  |  |             return APIResponse( | 
					
						
							|  |  |  |  |                 code=200, | 
					
						
							|  |  |  |  |                 message=f"查询成功,但路径异常:{e.detail}", | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |                 data=ModelResponse(** model) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=200, | 
					
						
							|  |  |  |  |             message="查询成功", | 
					
						
							|  |  |  |  |             data=ModelResponse(**model) | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 5. 更新模型信息(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | @router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息") | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | @encrypt_response() | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | async def update_model(model_id: int, model_update: ModelUpdateRequest): | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |         exist_model = cursor.fetchone() | 
					
						
							|  |  |  |  |         if not exist_model: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         update_fields = [] | 
					
						
							|  |  |  |  |         params = [] | 
					
						
							|  |  |  |  |         if model_update.name is not None: | 
					
						
							|  |  |  |  |             update_fields.append("name = %s") | 
					
						
							|  |  |  |  |             params.append(model_update.name) | 
					
						
							|  |  |  |  |         if model_update.description is not None: | 
					
						
							|  |  |  |  |             update_fields.append("description = %s") | 
					
						
							|  |  |  |  |             params.append(model_update.description) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         need_load_default = False | 
					
						
							|  |  |  |  |         if model_update.is_default is not None: | 
					
						
							|  |  |  |  |             if model_update.is_default: | 
					
						
							|  |  |  |  |                 cursor.execute("UPDATE model SET is_default = 0") | 
					
						
							|  |  |  |  |                 update_fields.append("is_default = 1") | 
					
						
							|  |  |  |  |                 need_load_default = True | 
					
						
							|  |  |  |  |             else: | 
					
						
							|  |  |  |  |                 cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1") | 
					
						
							|  |  |  |  |                 default_count = cursor.fetchone()["cnt"] | 
					
						
							|  |  |  |  |                 if default_count == 1 and exist_model["is_default"]: | 
					
						
							|  |  |  |  |                     raise HTTPException( | 
					
						
							|  |  |  |  |                         status_code=400, | 
					
						
							|  |  |  |  |                         detail="当前是唯一默认模型,不可取消!" | 
					
						
							|  |  |  |  |                     ) | 
					
						
							|  |  |  |  |                 update_fields.append("is_default = 0") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         if not update_fields: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=400, detail="至少需提供一个更新字段") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         params.append(model_id) | 
					
						
							|  |  |  |  |         update_sql = f"""
 | 
					
						
							|  |  |  |  |             UPDATE model  | 
					
						
							|  |  |  |  |             SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP  | 
					
						
							|  |  |  |  |             WHERE id = %s | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         cursor.execute(update_sql, params) | 
					
						
							|  |  |  |  |         conn.commit() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |         updated_model = cursor.fetchone() | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |         # 更新模型后重置版本标识 | 
					
						
							|  |  |  |  |         global _yolo_model, _current_model_version | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |         if need_load_default: | 
					
						
							|  |  |  |  |             valid_abs_path = get_valid_model_abs_path(updated_model["path"]) | 
					
						
							|  |  |  |  |             _yolo_model = load_yolo_model(valid_abs_path) | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |             if _yolo_model: | 
					
						
							|  |  |  |  |                 setattr(_yolo_model, "model_path", valid_abs_path) | 
					
						
							|  |  |  |  |                 model_stat = os.stat(valid_abs_path) | 
					
						
							|  |  |  |  |                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | 
					
						
							|  |  |  |  |             else: | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |                 raise HTTPException( | 
					
						
							|  |  |  |  |                     status_code=500, | 
					
						
							|  |  |  |  |                     detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})" | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=200, | 
					
						
							|  |  |  |  |             message="模型更新成功", | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |             data=ModelResponse(** updated_model) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         if conn: | 
					
						
							|  |  |  |  |             conn.rollback() | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 5.1 更换默认模型(添加置信度参数) | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | @router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)") | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | @encrypt_response() | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | async def set_default_model( | 
					
						
							|  |  |  |  |     model_id: int, | 
					
						
							|  |  |  |  |     conf_threshold: float = Query(0.8, ge=0.01, le=0.99, description="模型检测置信度阈值(0.01-0.99)") | 
					
						
							|  |  |  |  | ): | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  |         conn.autocommit = False  # 开启事务 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 1. 校验目标模型是否存在 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |         target_model = cursor.fetchone() | 
					
						
							|  |  |  |  |         if not target_model: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=404, detail=f"目标模型不存在!ID:{model_id}") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 2. 检查是否已为默认模型 | 
					
						
							|  |  |  |  |         if target_model["is_default"]: | 
					
						
							|  |  |  |  |             return APIResponse( | 
					
						
							|  |  |  |  |                 code=200, | 
					
						
							|  |  |  |  |                 message=f"模型ID:{model_id} 已是默认模型,无需更换和重启", | 
					
						
							|  |  |  |  |                 data=ModelResponse(**target_model) | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 3. 校验目标模型文件合法性 | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             valid_abs_path = get_valid_model_abs_path(target_model["path"]) | 
					
						
							|  |  |  |  |         except HTTPException as e: | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=400, | 
					
						
							|  |  |  |  |                 detail=f"目标模型文件非法,无法设为默认:{e.detail}" | 
					
						
							|  |  |  |  |             ) from e | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 4. 数据库事务:更新默认模型状态 | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP") | 
					
						
							|  |  |  |  |             cursor.execute( | 
					
						
							|  |  |  |  |                 "UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s", | 
					
						
							|  |  |  |  |                 (model_id,) | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  |             conn.commit() | 
					
						
							|  |  |  |  |         except MySQLError as e: | 
					
						
							|  |  |  |  |             conn.rollback() | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=500, | 
					
						
							|  |  |  |  |                 detail=f"更新默认模型状态失败(已回滚):{str(e)}" | 
					
						
							|  |  |  |  |             ) from e | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 5. 验证新模型可加载性 | 
					
						
							|  |  |  |  |         test_model = load_yolo_model(valid_abs_path) | 
					
						
							|  |  |  |  |         if not test_model: | 
					
						
							|  |  |  |  |             conn.rollback() | 
					
						
							|  |  |  |  |             raise HTTPException( | 
					
						
							|  |  |  |  |                 status_code=500, | 
					
						
							|  |  |  |  |                 detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path})" | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 6. 重新查询更新后的模型信息 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |         updated_model = cursor.fetchone() | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |         # 7. 重置版本标识和更新置信度 | 
					
						
							|  |  |  |  |         global _current_model_version, _current_conf_threshold | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |         _current_model_version = None | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |         _current_conf_threshold = conf_threshold  # 保存动态置信度 | 
					
						
							|  |  |  |  |         print(f"[更换默认模型] 已重置模型版本标识,设置新置信度:{conf_threshold}") | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 8. 延迟重启服务 | 
					
						
							|  |  |  |  |         print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})") | 
					
						
							|  |  |  |  |         threading.Timer( | 
					
						
							|  |  |  |  |             interval=1.0, | 
					
						
							|  |  |  |  |             function=restart_service | 
					
						
							|  |  |  |  |         ).start() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 9. 返回成功响应 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=200, | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |             message=f"已成功更换默认模型(ID:{model_id}),置信度:{conf_threshold}!服务将在1秒后自动重启以应用新模型", | 
					
						
							|  |  |  |  |             data=ModelResponse(** updated_model) | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         if conn: | 
					
						
							|  |  |  |  |             conn.rollback() | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         if conn: | 
					
						
							|  |  |  |  |             conn.autocommit = True | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 6. 删除模型(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | @encrypt_response() | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | async def delete_model(model_id: int): | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |         exist_model = cursor.fetchone() | 
					
						
							|  |  |  |  |         if not exist_model: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}") | 
					
						
							|  |  |  |  |         if exist_model["is_default"]: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=400, detail="默认模型不可删除!") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             model_abs_path_str = get_valid_model_abs_path(exist_model["path"]) | 
					
						
							|  |  |  |  |             model_abs_path = Path(model_abs_path_str) | 
					
						
							|  |  |  |  |         except HTTPException as e: | 
					
						
							|  |  |  |  |             cursor.execute("DELETE FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |             conn.commit() | 
					
						
							|  |  |  |  |             return APIResponse( | 
					
						
							|  |  |  |  |                 code=200, | 
					
						
							|  |  |  |  |                 message=f"记录删除成功,文件异常:{e.detail}", | 
					
						
							|  |  |  |  |                 data=None | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute("DELETE FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |         conn.commit() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         extra_msg = "" | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             model_abs_path.unlink() | 
					
						
							|  |  |  |  |             extra_msg = f"(已删除文件)" | 
					
						
							|  |  |  |  |         except Exception as e: | 
					
						
							|  |  |  |  |             extra_msg = f"(文件删除失败:{str(e)})" | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |         # 如果删除的是当前加载的模型,重置缓存 | 
					
						
							|  |  |  |  |         global _yolo_model, _current_model_version | 
					
						
							|  |  |  |  |         if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str: | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  |             _yolo_model = None | 
					
						
							| 
									
										
										
										
											2025-09-12 18:28:43 +08:00
										 |  |  |  |             _current_model_version = None | 
					
						
							|  |  |  |  |             print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str})") | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         return APIResponse( | 
					
						
							|  |  |  |  |             code=200, | 
					
						
							|  |  |  |  |             message=f"模型删除成功!ID:{model_id} {extra_msg}", | 
					
						
							|  |  |  |  |             data=None | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         if conn: | 
					
						
							|  |  |  |  |             conn.rollback() | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							|  |  |  |  |         db.close_connection(conn, cursor) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  | # 7. 下载模型文件(保持不变) | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | @router.get("/{model_id}/download", summary="下载模型文件") | 
					
						
							| 
									
										
										
										
											2025-09-15 18:35:43 +08:00
										 |  |  |  | @encrypt_response() | 
					
						
							| 
									
										
										
										
											2025-09-12 14:05:09 +08:00
										 |  |  |  | async def download_model(model_id: int): | 
					
						
							|  |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | 
					
						
							|  |  |  |  |         model = cursor.fetchone() | 
					
						
							|  |  |  |  |         if not model: | 
					
						
							|  |  |  |  |             raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         valid_abs_path = get_valid_model_abs_path(model["path"]) | 
					
						
							|  |  |  |  |         model_abs_path = Path(valid_abs_path) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         return FileResponse( | 
					
						
							|  |  |  |  |             path=model_abs_path, | 
					
						
							|  |  |  |  |             filename=f"model_{model_id}_{model['name']}.pt", | 
					
						
							|  |  |  |  |             media_type="application/octet-stream" | 
					
						
							|  |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     except MySQLError as e: | 
					
						
							|  |  |  |  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | 
					
						
							|  |  |  |  |     finally: | 
					
						
							| 
									
										
										
										
											2025-09-15 17:43:36 +08:00
										 |  |  |  |         db.close_connection(conn, cursor) |