Files
video_detect/router/model_router.py
2025-09-30 17:17:20 +08:00

270 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from pathlib import Path
from service.file_service import save_source_file
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.model_schema import (
ModelResponse,
ModelListResponse
)
from schema.response_schema import APIResponse
from service.model_service import ALLOWED_MODEL_EXT, MAX_MODEL_SIZE, load_yolo_model
# Router for model management; every endpoint below is mounted under /api/models.
router = APIRouter(
    tags=["模型管理"],
    prefix="/api/models",
)
# Upload a model
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
@encrypt_response()
async def upload_model(
    name: str = Form(..., description="模型名称"),
    description: str = Form(None, description="模型描述"),
    file: UploadFile = File(..., description=f"YOLO模型文件.pt、最大{MAX_MODEL_SIZE // 1024 // 1024}MB")
):
    """Upload a YOLO weight file and register it as a non-default model.

    The extension is checked against ``ALLOWED_MODEL_EXT`` and the size against
    ``MAX_MODEL_SIZE``. The file is then persisted with ``save_source_file`` and
    a row with ``is_default = 0`` is inserted into the ``model`` table.

    Args:
        name: Value stored in ``model.name``.
        description: Optional value stored in ``model.description``.
        file: The uploaded model file.

    Returns:
        APIResponse whose ``data`` is the newly inserted row as ``ModelResponse``.

    Raises:
        HTTPException: 400 for a disallowed extension or an oversized file;
            500 if the inserted row cannot be read back, on database errors
            (after rollback), or on any other unexpected failure.
    """
    conn = None
    cursor = None
    try:
        # Validate the extension; a missing filename is treated as "no extension".
        filename = file.filename or ""
        file_ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
        if file_ext not in ALLOWED_MODEL_EXT:
            raise HTTPException(
                status_code=400,
                detail=f"仅支持{ALLOWED_MODEL_EXT}格式、当前:{file_ext}"
            )
        # Validate the size. UploadFile.size can be None when the size is not
        # known; skip the check then instead of failing on `None > int`.
        if file.size is not None and file.size > MAX_MODEL_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB、当前{file.size // 1024 // 1024}MB"
            )
        # Persist the file; the returned path is what gets stored in the DB.
        file_path = save_source_file(file, "model")

        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        insert_sql = """
        INSERT INTO model (name, path, is_default, description, file_size)
        VALUES (%s, %s, 0, %s, %s)
        """
        cursor.execute(insert_sql, (name, file_path, description, file.size))
        conn.commit()
        # LAST_INSERT_ID() is per-connection, so this reads back the row just inserted.
        cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
        new_model = cursor.fetchone()
        if not new_model:
            raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
        return APIResponse(
            code=200,
            message="模型上传成功",
            data=ModelResponse(**new_model)
        )
    except HTTPException:
        # Deliberate HTTP errors raised above (400 validation, 500 "no record")
        # must reach the client unchanged, not be re-wrapped by the generic
        # `except Exception` handler below as "服务器错误".
        raise
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
    finally:
        await file.close()
        db.close_connection(conn, cursor)
# Fetch the model list
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
@encrypt_response()
async def get_model_list(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
    name: str = Query(None),
    is_default: bool = Query(None)
):
    """Return one page of models plus the total count matching the filters.

    Args:
        page: 1-based page number.
        page_size: Rows per page (1-100).
        name: Optional substring filter on ``model.name`` (SQL ``LIKE``).
        is_default: Optional filter on ``model.is_default`` (True -> 1, False -> 0).

    Returns:
        APIResponse whose ``data`` is a ``ModelListResponse`` with ``total`` and
        the page's ``models``, ordered by ``updated_at`` descending and then by
        ``id`` descending.

    Raises:
        HTTPException: 500 on database errors.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        conditions = []
        params = []
        if name:
            # NOTE(review): '%' and '_' inside `name` are not escaped, so they
            # act as LIKE wildcards.
            conditions.append("name LIKE %s")
            params.append(f"%{name}%")
        if is_default is not None:
            conditions.append("is_default = %s")
            params.append(1 if is_default else 0)
        # Only fixed SQL fragments are concatenated; all values stay parameterized.
        where_sql = " WHERE " + " AND ".join(conditions) if conditions else ""

        # Total number of matching rows
        cursor.execute("SELECT COUNT(*) AS total FROM model" + where_sql, params)
        total = cursor.fetchone()["total"]

        # Requested page. `id DESC` breaks ties on updated_at so that
        # LIMIT/OFFSET pages are deterministic and never repeat or skip rows.
        offset = (page - 1) * page_size
        list_sql = (
            "SELECT * FROM model" + where_sql
            + " ORDER BY updated_at DESC, id DESC LIMIT %s OFFSET %s"
        )
        cursor.execute(list_sql, [*params, page_size, offset])
        model_list = cursor.fetchall()
        return APIResponse(
            code=200,
            message="获取成功!",
            data=ModelListResponse(
                total=total,
                models=[ModelResponse(**model) for model in model_list]
            )
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# Switch the default model
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型")
@encrypt_response()
async def set_default_model(
    model_id: int
):
    """Mark ``model_id`` as the sole default model, then call ``load_yolo_model``.

    The flag changes run in one transaction (autocommit is disabled for this
    connection and restored in ``finally``).

    Args:
        model_id: Primary key of the model to make default.

    Returns:
        APIResponse. If the model is already default, ``data`` is that row as
        ``ModelResponse``; otherwise ``data`` is None.

    Raises:
        HTTPException: 404 if the model does not exist; 500 if the flag update
            fails (rolled back) or on other database errors.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        conn.autocommit = False
        # Check that the target model exists.
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        target_model = cursor.fetchone()
        if not target_model:
            raise HTTPException(status_code=404, detail="目标模型不存在!")
        # Nothing to do if it is already the default.
        if target_model["is_default"]:
            return APIResponse(
                code=200,
                message="已是默认模型、无需更换",
                data=ModelResponse(**target_model)
            )
        # Transaction: flip the default flag.
        try:
            # Only touch rows that are currently default. Resetting every row
            # would also rewrite updated_at on unrelated models, erasing their
            # real modification time and flattening list ordering.
            cursor.execute(
                "UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP WHERE is_default <> 0"
            )
            cursor.execute(
                "UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
                (model_id,)
            )
            conn.commit()
        except MySQLError as e:
            conn.rollback()
            raise HTTPException(
                status_code=500,
                detail=f"更新默认模型状态失败(已回滚):{str(e)}"
            ) from e
        # Presumably reloads the in-memory model from the new default
        # (TODO confirm in service.model_service). NOTE(review): if this raises,
        # the DB change is already committed.
        load_yolo_model()
        return APIResponse(
            code=200,
            message="更换成功",
            data=None
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        # Always close the connection, even if restoring autocommit fails.
        try:
            if conn:
                conn.autocommit = True
        finally:
            db.close_connection(conn, cursor)
# Route handler: delete a model
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
@encrypt_response()
async def delete_model(model_id: int):
    """Delete a model row and, on a best-effort basis, its file on disk.

    The default model and the model currently in use cannot be deleted. The DB
    row is deleted and committed first; file removal happens afterwards, and its
    outcome is appended to the response message.

    Args:
        model_id: Primary key of the ``model`` row to delete.

    Returns:
        APIResponse with ``data=None``; ``message`` includes the file outcome.

    Raises:
        HTTPException: 404 if the model does not exist; 400 if it is the default
            or currently in use; 500 on database errors (after rollback).
    """
    # Imported inside the handler so each call reads the values these
    # module-level globals of service.model_service hold at call time.
    from service.model_service import (
        current_yolo_model,
        current_model_absolute_path,
        load_yolo_model,  # used to reload a model afterwards if none is loaded
    )
    from service.file_service import get_absolute_path

    def _same_path(a, b):
        # Compare normalized forms so a str and a Path, or spellings that differ
        # only in separators/"..", are recognized as the same file.
        return os.path.normcase(os.path.abspath(a)) == os.path.normcase(os.path.abspath(b))

    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Look up the model to delete.
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        exist_model = cursor.fetchone()
        if not exist_model:
            raise HTTPException(status_code=404, detail="模型不存在!")
        # Guard 1: the default model cannot be deleted.
        if exist_model["is_default"]:
            raise HTTPException(status_code=400, detail="默认模型不可删除!")
        # Guard 2: the model currently in use cannot be deleted.
        del_model_abs_path = get_absolute_path(exist_model["path"])
        if current_model_absolute_path and _same_path(del_model_abs_path, current_model_absolute_path):
            raise HTTPException(status_code=400, detail="该模型正在使用中,禁止删除!")
        # Delete the DB row first, so a file-removal failure cannot leave a row
        # pointing at a missing file.
        cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
        conn.commit()
        # Then remove the file. This is best-effort: failures are reported in
        # the message rather than failing the already-committed delete.
        try:
            if os.path.exists(del_model_abs_path):
                os.remove(del_model_abs_path)
                extra_msg = "(本地文件已同步删除)"
            else:
                extra_msg = "(本地文件不存在,无需删除)"
        except Exception as e:
            extra_msg = f"(本地文件删除失败:{str(e)})"
        # If no model was loaded when this request started, try to load one now.
        if current_yolo_model is None:
            try:
                load_yolo_model()
                print("[模型删除后] 重新加载默认模型成功")
            except Exception as e:
                print(f"[模型删除后] 重新加载默认模型失败:{str(e)}")
        return APIResponse(
            code=200,
            # Report the file outcome (previously computed but discarded).
            message=f"模型删除成功!{extra_msg}",
            data=None
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)