270 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			270 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import os | |||
|  | from pathlib import Path | |||
|  | from service.file_service import save_source_file | |||
|  | 
 | |||
|  | from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query | |||
|  | from mysql.connector import Error as MySQLError | |||
|  | 
 | |||
|  | from ds.db import db | |||
|  | from encryption.encrypt_decorator import encrypt_response | |||
|  | from schema.model_schema import ( | |||
|  |     ModelResponse, | |||
|  |     ModelListResponse | |||
|  | ) | |||
|  | from schema.response_schema import APIResponse | |||
|  | from service.model_service import ALLOWED_MODEL_EXT, MAX_MODEL_SIZE, load_yolo_model | |||
|  | 
 | |||
|  | router = APIRouter(prefix="/api/models", tags=["模型管理"]) | |||
|  | 
 | |||
|  | 
 | |||
|  | # 上传模型 | |||
|  | @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") | |||
|  | @encrypt_response() | |||
|  | async def upload_model( | |||
|  |         name: str = Form(..., description="模型名称"), | |||
|  |         description: str = Form(None, description="模型描述"), | |||
|  |         file: UploadFile = File(..., description=f"YOLO模型文件(.pt、最大{MAX_MODEL_SIZE // 1024 // 1024}MB)") | |||
|  | ): | |||
|  |     conn = None | |||
|  |     cursor = None | |||
|  |     try: | |||
|  |         # 校验文件格式 | |||
|  |         file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else "" | |||
|  |         if file_ext not in ALLOWED_MODEL_EXT: | |||
|  |             raise HTTPException( | |||
|  |                 status_code=400, | |||
|  |                 detail=f"仅支持{ALLOWED_MODEL_EXT}格式、当前:{file_ext}" | |||
|  |             ) | |||
|  | 
 | |||
|  |         # 校验文件大小 | |||
|  |         if file.size > MAX_MODEL_SIZE: | |||
|  |             raise HTTPException( | |||
|  |                 status_code=400, | |||
|  |                 detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB、当前{file.size // 1024 // 1024}MB" | |||
|  |             ) | |||
|  |         # 保存文件 | |||
|  |         file_path = save_source_file(file, "model") | |||
|  | 
 | |||
|  |         # 数据库操作 | |||
|  |         conn = db.get_connection() | |||
|  |         cursor = conn.cursor(dictionary=True) | |||
|  | 
 | |||
|  |         insert_sql = """
 | |||
|  |             INSERT INTO model (name, path, is_default, description, file_size) | |||
|  |             VALUES (%s, %s, 0, %s, %s) | |||
|  |         """
 | |||
|  |         cursor.execute(insert_sql, (name, file_path, description, file.size)) | |||
|  |         conn.commit() | |||
|  | 
 | |||
|  |         # 获取新增记录 | |||
|  |         cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()") | |||
|  |         new_model = cursor.fetchone() | |||
|  |         if not new_model: | |||
|  |             raise HTTPException(status_code=500, detail="上传成功但无法获取记录") | |||
|  | 
 | |||
|  |         return APIResponse( | |||
|  |             code=200, | |||
|  |             message=f"模型上传成功", | |||
|  |             data=ModelResponse(**new_model) | |||
|  |         ) | |||
|  | 
 | |||
|  |     except MySQLError as e: | |||
|  |         if conn: | |||
|  |             conn.rollback() | |||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | |||
|  |     except Exception as e: | |||
|  |         raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e | |||
|  |     finally: | |||
|  |         await file.close() | |||
|  |         db.close_connection(conn, cursor) | |||
|  | 
 | |||
|  | 
 | |||
|  | # 获取模型列表 | |||
|  | @router.get("", response_model=APIResponse, summary="获取模型列表(分页)") | |||
|  | @encrypt_response() | |||
|  | async def get_model_list( | |||
|  |         page: int = Query(1, ge=1), | |||
|  |         page_size: int = Query(10, ge=1, le=100), | |||
|  |         name: str = Query(None), | |||
|  |         is_default: bool = Query(None) | |||
|  | ): | |||
|  |     conn = None | |||
|  |     cursor = None | |||
|  |     try: | |||
|  |         conn = db.get_connection() | |||
|  |         cursor = conn.cursor(dictionary=True) | |||
|  | 
 | |||
|  |         where_clause = [] | |||
|  |         params = [] | |||
|  |         if name: | |||
|  |             where_clause.append("name LIKE %s") | |||
|  |             params.append(f"%{name}%") | |||
|  |         if is_default is not None: | |||
|  |             where_clause.append("is_default = %s") | |||
|  |             params.append(1 if is_default else 0) | |||
|  | 
 | |||
|  |         # 总记录数 | |||
|  |         count_sql = "SELECT COUNT(*) AS total FROM model" | |||
|  |         if where_clause: | |||
|  |             count_sql += " WHERE " + " AND ".join(where_clause) | |||
|  |         cursor.execute(count_sql, params) | |||
|  |         total = cursor.fetchone()["total"] | |||
|  | 
 | |||
|  |         # 分页数据 | |||
|  |         offset = (page - 1) * page_size | |||
|  |         list_sql = "SELECT * FROM model" | |||
|  |         if where_clause: | |||
|  |             list_sql += " WHERE " + " AND ".join(where_clause) | |||
|  |         list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s" | |||
|  |         params.extend([page_size, offset]) | |||
|  | 
 | |||
|  |         cursor.execute(list_sql, params) | |||
|  |         model_list = cursor.fetchall() | |||
|  | 
 | |||
|  |         return APIResponse( | |||
|  |             code=200, | |||
|  |             message=f"获取成功!", | |||
|  |             data=ModelListResponse( | |||
|  |                 total=total, | |||
|  |                 models=[ModelResponse(**model) for model in model_list] | |||
|  |             ) | |||
|  |         ) | |||
|  | 
 | |||
|  |     except MySQLError as e: | |||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | |||
|  |     finally: | |||
|  |         db.close_connection(conn, cursor) | |||
|  | 
 | |||
|  | 
 | |||
|  | # 更换默认模型 | |||
|  | @router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型") | |||
|  | @encrypt_response() | |||
|  | async def set_default_model( | |||
|  |         model_id: int | |||
|  | ): | |||
|  |     conn = None | |||
|  |     cursor = None | |||
|  |     try: | |||
|  |         conn = db.get_connection() | |||
|  |         cursor = conn.cursor(dictionary=True) | |||
|  |         conn.autocommit = False | |||
|  | 
 | |||
|  |         # 校验目标模型是否存在 | |||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | |||
|  |         target_model = cursor.fetchone() | |||
|  |         if not target_model: | |||
|  |             raise HTTPException(status_code=404, detail=f"目标模型不存在!") | |||
|  | 
 | |||
|  |         # 检查是否已为默认模型 | |||
|  |         if target_model["is_default"]: | |||
|  |             return APIResponse( | |||
|  |                 code=200, | |||
|  |                 message=f"已是默认模型、无需更换", | |||
|  |                 data=ModelResponse(**target_model) | |||
|  |             ) | |||
|  | 
 | |||
|  |         # 数据库事务:更新默认模型状态 | |||
|  |         try: | |||
|  |             cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP") | |||
|  |             cursor.execute( | |||
|  |                 "UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s", | |||
|  |                 (model_id,) | |||
|  |             ) | |||
|  |             conn.commit() | |||
|  |         except MySQLError as e: | |||
|  |             conn.rollback() | |||
|  |             raise HTTPException( | |||
|  |                 status_code=500, | |||
|  |                 detail=f"更新默认模型状态失败(已回滚):{str(e)}" | |||
|  |             ) from e | |||
|  | 
 | |||
|  |         # 更新模型 | |||
|  |         load_yolo_model() | |||
|  |         # 返回成功响应 | |||
|  |         return APIResponse( | |||
|  |             code=200, | |||
|  |             message=f"更换成功", | |||
|  |             data=None | |||
|  |         ) | |||
|  | 
 | |||
|  |     except MySQLError as e: | |||
|  |         if conn: | |||
|  |             conn.rollback() | |||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | |||
|  |     finally: | |||
|  |         if conn: | |||
|  |             conn.autocommit = True | |||
|  |         db.close_connection(conn, cursor) | |||
|  | 
 | |||
|  | 
 | |||
|  | # 路由文件(如 model_router.py)中的删除接口 | |||
|  | @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") | |||
|  | @encrypt_response() | |||
|  | async def delete_model(model_id: int): | |||
|  |     # 1. 正确导入 model_service 中的全局变量(关键修复:变量名匹配) | |||
|  |     from service.model_service import ( | |||
|  |         current_yolo_model, | |||
|  |         current_model_absolute_path, | |||
|  |         load_yolo_model  # 用于删除后重新加载模型(可选) | |||
|  |     ) | |||
|  | 
 | |||
|  |     conn = None | |||
|  |     cursor = None | |||
|  |     try: | |||
|  |         conn = db.get_connection() | |||
|  |         cursor = conn.cursor(dictionary=True) | |||
|  | 
 | |||
|  |         # 2. 查询待删除模型信息 | |||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | |||
|  |         exist_model = cursor.fetchone() | |||
|  |         if not exist_model: | |||
|  |             raise HTTPException(status_code=404, detail=f"模型不存在!") | |||
|  | 
 | |||
|  |         # 3. 关键判断:①默认模型不可删 ②正在使用的模型不可删 | |||
|  |         if exist_model["is_default"]: | |||
|  |             raise HTTPException(status_code=400, detail="默认模型不可删除!") | |||
|  | 
 | |||
|  |         # 计算待删除模型的绝对路径(与 model_service 逻辑一致) | |||
|  |         from service.file_service import get_absolute_path | |||
|  |         del_model_abs_path = get_absolute_path(exist_model["path"]) | |||
|  | 
 | |||
|  |         # 判断是否正在使用(对比 current_model_absolute_path) | |||
|  |         if current_model_absolute_path and del_model_abs_path == current_model_absolute_path: | |||
|  |             raise HTTPException(status_code=400, detail="该模型正在使用中,禁止删除!") | |||
|  | 
 | |||
|  |         # 4. 先删除数据库记录(避免文件删除失败导致数据不一致) | |||
|  |         cursor.execute("DELETE FROM model WHERE id = %s", (model_id,)) | |||
|  |         conn.commit() | |||
|  | 
 | |||
|  |         # 5. 再删除本地文件(捕获文件删除异常,不影响数据库删除结果) | |||
|  |         extra_msg = "" | |||
|  |         try: | |||
|  |             if os.path.exists(del_model_abs_path): | |||
|  |                 os.remove(del_model_abs_path)  # 或用 Path(del_model_abs_path).unlink() | |||
|  |                 extra_msg = "(本地文件已同步删除)" | |||
|  |             else: | |||
|  |                 extra_msg = "(本地文件不存在,无需删除)" | |||
|  |         except Exception as e: | |||
|  |             extra_msg = f"(本地文件删除失败:{str(e)})" | |||
|  | 
 | |||
|  |         # 6. 若删除后当前模型为空(极端情况),重新加载默认模型(可选优化) | |||
|  |         if current_yolo_model is None: | |||
|  |             try: | |||
|  |                 load_yolo_model() | |||
|  |                 print(f"[模型删除后] 重新加载默认模型成功") | |||
|  |             except Exception as e: | |||
|  |                 print(f"[模型删除后] 重新加载默认模型失败:{str(e)}") | |||
|  | 
 | |||
|  |         return APIResponse( | |||
|  |             code=200, | |||
|  |             message=f"模型删除成功!", | |||
|  |             data=None | |||
|  |         ) | |||
|  | 
 | |||
|  |     except MySQLError as e: | |||
|  |         if conn: | |||
|  |             conn.rollback() | |||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | |||
|  |     finally: | |||
|  |         db.close_connection(conn, cursor) |