270 lines
9.3 KiB
Python
270 lines
9.3 KiB
Python
import os
|
||
from pathlib import Path
|
||
from service.file_service import save_source_file
|
||
|
||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||
from mysql.connector import Error as MySQLError
|
||
|
||
from ds.db import db
|
||
from encryption.encrypt_decorator import encrypt_response
|
||
from schema.model_schema import (
|
||
ModelResponse,
|
||
ModelListResponse
|
||
)
|
||
from schema.response_schema import APIResponse
|
||
from service.model_service import ALLOWED_MODEL_EXT, MAX_MODEL_SIZE, load_yolo_model
|
||
|
||
router = APIRouter(prefix="/api/models", tags=["模型管理"])
|
||
|
||
|
||
# 上传模型
|
||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||
@encrypt_response()
|
||
async def upload_model(
|
||
name: str = Form(..., description="模型名称"),
|
||
description: str = Form(None, description="模型描述"),
|
||
file: UploadFile = File(..., description=f"YOLO模型文件(.pt、最大{MAX_MODEL_SIZE // 1024 // 1024}MB)")
|
||
):
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
# 校验文件格式
|
||
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||
if file_ext not in ALLOWED_MODEL_EXT:
|
||
raise HTTPException(
|
||
status_code=400,
|
||
detail=f"仅支持{ALLOWED_MODEL_EXT}格式、当前:{file_ext}"
|
||
)
|
||
|
||
# 校验文件大小
|
||
if file.size > MAX_MODEL_SIZE:
|
||
raise HTTPException(
|
||
status_code=400,
|
||
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB、当前{file.size // 1024 // 1024}MB"
|
||
)
|
||
# 保存文件
|
||
file_path = save_source_file(file, "model")
|
||
|
||
# 数据库操作
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
insert_sql = """
|
||
INSERT INTO model (name, path, is_default, description, file_size)
|
||
VALUES (%s, %s, 0, %s, %s)
|
||
"""
|
||
cursor.execute(insert_sql, (name, file_path, description, file.size))
|
||
conn.commit()
|
||
|
||
# 获取新增记录
|
||
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
|
||
new_model = cursor.fetchone()
|
||
if not new_model:
|
||
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||
|
||
return APIResponse(
|
||
code=200,
|
||
message=f"模型上传成功",
|
||
data=ModelResponse(**new_model)
|
||
)
|
||
|
||
except MySQLError as e:
|
||
if conn:
|
||
conn.rollback()
|
||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
|
||
finally:
|
||
await file.close()
|
||
db.close_connection(conn, cursor)
|
||
|
||
|
||
# 获取模型列表
|
||
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||
@encrypt_response()
|
||
async def get_model_list(
|
||
page: int = Query(1, ge=1),
|
||
page_size: int = Query(10, ge=1, le=100),
|
||
name: str = Query(None),
|
||
is_default: bool = Query(None)
|
||
):
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
where_clause = []
|
||
params = []
|
||
if name:
|
||
where_clause.append("name LIKE %s")
|
||
params.append(f"%{name}%")
|
||
if is_default is not None:
|
||
where_clause.append("is_default = %s")
|
||
params.append(1 if is_default else 0)
|
||
|
||
# 总记录数
|
||
count_sql = "SELECT COUNT(*) AS total FROM model"
|
||
if where_clause:
|
||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||
cursor.execute(count_sql, params)
|
||
total = cursor.fetchone()["total"]
|
||
|
||
# 分页数据
|
||
offset = (page - 1) * page_size
|
||
list_sql = "SELECT * FROM model"
|
||
if where_clause:
|
||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||
params.extend([page_size, offset])
|
||
|
||
cursor.execute(list_sql, params)
|
||
model_list = cursor.fetchall()
|
||
|
||
return APIResponse(
|
||
code=200,
|
||
message=f"获取成功!",
|
||
data=ModelListResponse(
|
||
total=total,
|
||
models=[ModelResponse(**model) for model in model_list]
|
||
)
|
||
)
|
||
|
||
except MySQLError as e:
|
||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||
finally:
|
||
db.close_connection(conn, cursor)
|
||
|
||
|
||
# 更换默认模型
|
||
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型")
|
||
@encrypt_response()
|
||
async def set_default_model(
|
||
model_id: int
|
||
):
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
conn.autocommit = False
|
||
|
||
# 校验目标模型是否存在
|
||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||
target_model = cursor.fetchone()
|
||
if not target_model:
|
||
raise HTTPException(status_code=404, detail=f"目标模型不存在!")
|
||
|
||
# 检查是否已为默认模型
|
||
if target_model["is_default"]:
|
||
return APIResponse(
|
||
code=200,
|
||
message=f"已是默认模型、无需更换",
|
||
data=ModelResponse(**target_model)
|
||
)
|
||
|
||
# 数据库事务:更新默认模型状态
|
||
try:
|
||
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
|
||
cursor.execute(
|
||
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
|
||
(model_id,)
|
||
)
|
||
conn.commit()
|
||
except MySQLError as e:
|
||
conn.rollback()
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
|
||
) from e
|
||
|
||
# 更新模型
|
||
load_yolo_model()
|
||
# 返回成功响应
|
||
return APIResponse(
|
||
code=200,
|
||
message=f"更换成功",
|
||
data=None
|
||
)
|
||
|
||
except MySQLError as e:
|
||
if conn:
|
||
conn.rollback()
|
||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||
finally:
|
||
if conn:
|
||
conn.autocommit = True
|
||
db.close_connection(conn, cursor)
|
||
|
||
|
||
# 路由文件(如 model_router.py)中的删除接口
|
||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||
@encrypt_response()
|
||
async def delete_model(model_id: int):
|
||
# 1. 正确导入 model_service 中的全局变量(关键修复:变量名匹配)
|
||
from service.model_service import (
|
||
current_yolo_model,
|
||
current_model_absolute_path,
|
||
load_yolo_model # 用于删除后重新加载模型(可选)
|
||
)
|
||
|
||
conn = None
|
||
cursor = None
|
||
try:
|
||
conn = db.get_connection()
|
||
cursor = conn.cursor(dictionary=True)
|
||
|
||
# 2. 查询待删除模型信息
|
||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||
exist_model = cursor.fetchone()
|
||
if not exist_model:
|
||
raise HTTPException(status_code=404, detail=f"模型不存在!")
|
||
|
||
# 3. 关键判断:①默认模型不可删 ②正在使用的模型不可删
|
||
if exist_model["is_default"]:
|
||
raise HTTPException(status_code=400, detail="默认模型不可删除!")
|
||
|
||
# 计算待删除模型的绝对路径(与 model_service 逻辑一致)
|
||
from service.file_service import get_absolute_path
|
||
del_model_abs_path = get_absolute_path(exist_model["path"])
|
||
|
||
# 判断是否正在使用(对比 current_model_absolute_path)
|
||
if current_model_absolute_path and del_model_abs_path == current_model_absolute_path:
|
||
raise HTTPException(status_code=400, detail="该模型正在使用中,禁止删除!")
|
||
|
||
# 4. 先删除数据库记录(避免文件删除失败导致数据不一致)
|
||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||
conn.commit()
|
||
|
||
# 5. 再删除本地文件(捕获文件删除异常,不影响数据库删除结果)
|
||
extra_msg = ""
|
||
try:
|
||
if os.path.exists(del_model_abs_path):
|
||
os.remove(del_model_abs_path) # 或用 Path(del_model_abs_path).unlink()
|
||
extra_msg = "(本地文件已同步删除)"
|
||
else:
|
||
extra_msg = "(本地文件不存在,无需删除)"
|
||
except Exception as e:
|
||
extra_msg = f"(本地文件删除失败:{str(e)})"
|
||
|
||
# 6. 若删除后当前模型为空(极端情况),重新加载默认模型(可选优化)
|
||
if current_yolo_model is None:
|
||
try:
|
||
load_yolo_model()
|
||
print(f"[模型删除后] 重新加载默认模型成功")
|
||
except Exception as e:
|
||
print(f"[模型删除后] 重新加载默认模型失败:{str(e)}")
|
||
|
||
return APIResponse(
|
||
code=200,
|
||
message=f"模型删除成功!",
|
||
data=None
|
||
)
|
||
|
||
except MySQLError as e:
|
||
if conn:
|
||
conn.rollback()
|
||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||
finally:
|
||
db.close_connection(conn, cursor)
|