# Source: video/service/model_service.py (Python, 498 lines, 17 KiB)

from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
import os
import shutil
from pathlib import Path
from datetime import datetime
# 复用项目依赖
from ds.db import db
from schema.model_schema import (
ModelCreateRequest,
ModelUpdateRequest,
ModelResponse,
ModelListResponse
)
from schema.response_schema import APIResponse
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
# Path configuration: the project root is the grandparent of this file.
CURRENT_FILE_PATH = Path(__file__).resolve()
PROJECT_ROOT = CURRENT_FILE_PATH.parents[1]
MODEL_SAVE_ROOT = PROJECT_ROOT.joinpath("resource", "models")
MODEL_SAVE_ROOT.mkdir(parents=True, exist_ok=True)
# Prefix stripped from absolute file paths before they are stored in the DB.
DB_PATH_PREFIX_TO_REMOVE = f"{PROJECT_ROOT}{os.sep}"

# Model file limits.
MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB
ALLOWED_MODEL_EXT = {"pt"}

# Module-level slot holding the currently loaded YOLO model (None until loaded).
_yolo_model = None

router = APIRouter(prefix="/models", tags=["模型管理"])
# Utility: validate a stored model path.
def get_valid_model_abs_path(relative_path: str) -> str:
    """Resolve a stored model path and verify it points to a usable model file.

    Args:
        relative_path: Path relative to ``PROJECT_ROOT`` using ``/`` separators
            (the form written to the database by the upload endpoint).

    Returns:
        The resolved absolute path of the model file, as a string.

    Raises:
        HTTPException: 400 when the path lies outside ``MODEL_SAVE_ROOT``, is
            not a regular file, exceeds ``MAX_MODEL_SIZE`` or has an extension
            not listed in ``ALLOWED_MODEL_EXT``; 404 when the file does not
            exist; 500 for any other error while processing the path.
    """
    try:
        model_abs_path = (PROJECT_ROOT / relative_path.replace("/", os.sep)).resolve()
        model_abs_path_str = str(model_abs_path)

        # Containment must be checked component-wise: a plain string prefix
        # test would also accept sibling directories such as
        # ".../resource/models_other/x.pt".
        try:
            model_abs_path.relative_to(MODEL_SAVE_ROOT.resolve())
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
            ) from None

        if not model_abs_path.exists():
            raise HTTPException(
                status_code=404,
                detail=f"模型文件不存在!路径:{model_abs_path_str}"
            )
        if not model_abs_path.is_file():
            raise HTTPException(
                status_code=400,
                detail=f"路径不是文件!路径:{model_abs_path_str}"
            )

        file_size = model_abs_path.stat().st_size
        if file_size > MAX_MODEL_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"模型文件过大({file_size // 1024 // 1024}MB超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
            )

        file_ext = model_abs_path.suffix.lower()
        if file_ext not in {f".{ext}" for ext in ALLOWED_MODEL_EXT}:
            raise HTTPException(
                status_code=400,
                detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
            )

        print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
        return model_abs_path_str
    except HTTPException:
        # Validation failures already carry the intended status code.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"路径处理失败:{str(e)}"
        ) from e
# 1. Upload a model
def _remove_uncommitted_upload(saved_file_path):
    """Delete an uploaded file whose database record was never committed."""
    if saved_file_path is not None and saved_file_path.exists():
        saved_file_path.unlink()


@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
async def upload_model(
    name: str = Form(..., description="模型名称"),
    description: str = Form(None, description="模型描述"),
    is_default: bool = Form(False, description="是否设为默认模型"),
    file: UploadFile = File(..., description=f"YOLO模型文件.pt最大{MAX_MODEL_SIZE // 1024 // 1024}MB")
):
    """Store an uploaded ``.pt`` model file and register it in the ``model`` table.

    The file is written under ``MODEL_SAVE_ROOT`` with a timestamped name, a
    row is inserted (clearing the previous default first when ``is_default``
    is set), and the new default model is loaded into the module-level
    ``_yolo_model``.

    Args:
        name: Model name stored in the database.
        description: Optional model description.
        is_default: Whether the uploaded model becomes the default model.
        file: The uploaded model file.

    Returns:
        ``APIResponse`` with code 201 and the inserted record.

    Raises:
        HTTPException: 400 for a disallowed extension or an oversized file;
            500 for database errors, a failed default-model load or any other
            unexpected error.
    """
    global _yolo_model
    conn = None
    cursor = None
    saved_file_path = None
    # Once the row is committed the file must stay on disk, otherwise the
    # database would reference a file that no longer exists.
    committed = False
    try:
        # Validate the upload.
        original_name = file.filename or ""
        file_ext = original_name.split(".")[-1].lower() if "." in original_name else ""
        if file_ext not in ALLOWED_MODEL_EXT:
            raise HTTPException(
                status_code=400,
                detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
            )
        declared_size = file.size
        if declared_size is not None and declared_size > MAX_MODEL_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB当前{declared_size // 1024 // 1024}MB"
            )

        # Save the file. Only the final path component of the client-supplied
        # name is used so it cannot point outside MODEL_SAVE_ROOT.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        base_name = os.path.basename(original_name.replace("\\", "/"))
        safe_filename = f"model_{timestamp}_{base_name.replace(' ', '_')}"
        saved_file_path = MODEL_SAVE_ROOT / safe_filename
        with open(saved_file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        saved_file_path.chmod(0o644)  # set file permissions

        # Enforce the limit on what was actually written (covers uploads
        # whose size was not reported up front).
        file_size = saved_file_path.stat().st_size
        if file_size > MAX_MODEL_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB当前{file_size // 1024 // 1024}MB"
            )

        # Path stored in the database: relative to the project root, "/"-separated.
        db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/")

        # Database operations.
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        if is_default:
            cursor.execute("UPDATE model SET is_default = 0")
        insert_sql = """
            INSERT INTO model (name, path, is_default, description, file_size)
            VALUES (%s, %s, %s, %s, %s)
        """
        cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file_size))
        conn.commit()
        committed = True

        cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
        new_model = cursor.fetchone()
        if not new_model:
            raise HTTPException(status_code=500, detail="上传成功但无法获取记录")

        # Load the new default model.
        if is_default:
            valid_abs_path = get_valid_model_abs_path(db_relative_path)
            _yolo_model = load_yolo_model(valid_abs_path)
            if not _yolo_model:
                raise HTTPException(
                    status_code=500,
                    detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path}"
                )

        return APIResponse(
            code=201,
            message=f"模型上传成功ID{new_model['id']}",
            data=ModelResponse(**new_model)
        )
    except HTTPException:
        # Keep the intended status code (e.g. 400) instead of letting the
        # generic handler below rewrite it to 500.
        if not committed:
            _remove_uncommitted_upload(saved_file_path)
        raise
    except MySQLError as e:
        if conn:
            conn.rollback()
        if not committed:
            _remove_uncommitted_upload(saved_file_path)
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    except Exception as e:
        if not committed:
            _remove_uncommitted_upload(saved_file_path)
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
    finally:
        # Release the DB connection even if closing the upload fails.
        try:
            await file.close()
        finally:
            db.close_connection(conn, cursor)
# 2. List models
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
async def get_model_list(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
    name: str = Query(None),
    is_default: bool = Query(None)
):
    """Return one page of model records, optionally filtered.

    Args:
        page: 1-based page number.
        page_size: Number of records per page (1-100).
        name: Optional substring matched against the model name with ``LIKE``.
        is_default: Optional filter on the ``is_default`` flag.

    Returns:
        ``APIResponse`` with code 200 whose data holds the total number of
        matching records and the records of the requested page, ordered by
        ``updated_at`` descending.

    Raises:
        HTTPException: 500 on a database error.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        # Collect the optional filters and their bound values.
        conditions = []
        filter_args = []
        if name:
            conditions.append("name LIKE %s")
            filter_args.append(f"%{name}%")
        if is_default is not None:
            conditions.append("is_default = %s")
            filter_args.append(1 if is_default else 0)
        where_sql = (" WHERE " + " AND ".join(conditions)) if conditions else ""

        # Total number of matching records.
        cursor.execute("SELECT COUNT(*) AS total FROM model" + where_sql, filter_args)
        total = cursor.fetchone()["total"]

        # Requested page.
        page_args = filter_args + [page_size, (page - 1) * page_size]
        cursor.execute(
            "SELECT * FROM model" + where_sql + " ORDER BY updated_at DESC LIMIT %s OFFSET %s",
            page_args,
        )
        rows = cursor.fetchall()

        listing = ModelListResponse(
            total=total,
            models=[ModelResponse(**row) for row in rows]
        )
        return APIResponse(
            code=200,
            message=f"获取成功!共{total}条记录",
            data=listing
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# 3. Get the default model
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
async def get_default_model():
    """Return the record flagged as default, loading the model if none is loaded.

    Returns:
        ``APIResponse`` with code 200 and the default model record.

    Raises:
        HTTPException: 404 when no default model exists; whatever
            ``get_valid_model_abs_path`` raises for an invalid file; 500 when
            the model cannot be loaded or on a database error.
    """
    global _yolo_model
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE is_default = 1")
        record = cursor.fetchone()
        if not record:
            raise HTTPException(status_code=404, detail="暂无默认模型")

        abs_path = get_valid_model_abs_path(record["path"])
        # Reuse an already loaded model; otherwise load from the default path.
        _yolo_model = _yolo_model or load_yolo_model(abs_path)
        if not _yolo_model:
            raise HTTPException(
                status_code=500,
                detail=f"默认模型存在,但加载失败(路径:{abs_path}"
            )

        return APIResponse(
            code=200,
            message="默认模型查询成功",
            data=ModelResponse(**record)
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# 4. Get a single model's details
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
async def get_model(model_id: int):
    """Return one model record, noting in the message when its file is invalid.

    Args:
        model_id: Primary key of the model record.

    Returns:
        ``APIResponse`` with code 200 and the record. A failed path validation
        is reported in the message only; the request still succeeds.

    Raises:
        HTTPException: 404 when the record does not exist; 500 on a database
            error.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        record = cursor.fetchone()
        if not record:
            raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")

        # The path check is informational only: its outcome changes the message.
        message = "查询成功"
        try:
            get_valid_model_abs_path(record["path"])
        except HTTPException as path_error:
            message = f"查询成功,但路径异常:{path_error.detail}"

        return APIResponse(
            code=200,
            message=message,
            data=ModelResponse(**record)
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# 5. Update model information
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
async def update_model(model_id: int, model_update: ModelUpdateRequest):
    """Update a model's name, description and/or default flag.

    When the model is promoted to default, its file is validated BEFORE any
    row is modified, so a model whose file is missing or invalid can never be
    committed as the default. After the commit the new default is loaded into
    the module-level ``_yolo_model``.

    Args:
        model_id: Primary key of the model record.
        model_update: Fields to change; attributes left as ``None`` are not
            touched.

    Returns:
        ``APIResponse`` with code 200 and the updated record.

    Raises:
        HTTPException: 404 when the record does not exist; 400 when no field
            is supplied or when un-defaulting the only default model; whatever
            ``get_valid_model_abs_path`` raises for an invalid file; 500 when
            the new default cannot be loaded or on a database error.
    """
    global _yolo_model
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        exist_model = cursor.fetchone()
        if not exist_model:
            raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")

        update_fields = []
        params = []
        if model_update.name is not None:
            update_fields.append("name = %s")
            params.append(model_update.name)
        if model_update.description is not None:
            update_fields.append("description = %s")
            params.append(model_update.description)

        # Absolute path of the model to load after the commit; stays None
        # unless this request promotes the model to default.
        new_default_abs_path = None
        if model_update.is_default is not None:
            if model_update.is_default:
                # Validate first: the path is not changed by this update, so a
                # failure here leaves the database untouched.
                new_default_abs_path = get_valid_model_abs_path(exist_model["path"])
                cursor.execute("UPDATE model SET is_default = 0")
                update_fields.append("is_default = 1")
            else:
                cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
                default_count = cursor.fetchone()["cnt"]
                if default_count == 1 and exist_model["is_default"]:
                    raise HTTPException(
                        status_code=400,
                        detail="当前是唯一默认模型,不可取消!"
                    )
                update_fields.append("is_default = 0")

        if not update_fields:
            raise HTTPException(status_code=400, detail="至少需提供一个更新字段")

        params.append(model_id)
        update_sql = f"""
            UPDATE model
            SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
            WHERE id = %s
        """
        cursor.execute(update_sql, params)
        conn.commit()

        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        updated_model = cursor.fetchone()

        if new_default_abs_path is not None:
            _yolo_model = load_yolo_model(new_default_abs_path)
            if not _yolo_model:
                raise HTTPException(
                    status_code=500,
                    detail=f"更新成功,但加载新默认模型失败(路径:{new_default_abs_path}"
                )

        return APIResponse(
            code=200,
            message="模型更新成功",
            data=ModelResponse(**updated_model)
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# 6. Delete a model
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
async def delete_model(model_id: int):
    """Delete a non-default model record and, when possible, its file.

    The database row is removed even if the stored path fails validation; in
    that case the file is left alone and the problem is reported in the
    response message. A failure to delete the file is likewise reported in
    the message rather than raised.

    Args:
        model_id: Primary key of the model record.

    Returns:
        ``APIResponse`` with code 200 and ``data=None``.

    Raises:
        HTTPException: 404 when the record does not exist; 400 when it is the
            default model; 500 on a database error.
    """
    global _yolo_model
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        record = cursor.fetchone()
        if not record:
            raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
        if record["is_default"]:
            raise HTTPException(status_code=400, detail="默认模型不可删除!")

        # Validate the stored path; a failure is remembered, not raised.
        abs_path_str = None
        path_problem = None
        try:
            abs_path_str = get_valid_model_abs_path(record["path"])
        except HTTPException as path_error:
            path_problem = path_error.detail

        # The record is removed regardless of the file's state.
        cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
        conn.commit()

        if abs_path_str is None:
            return APIResponse(
                code=200,
                message=f"记录删除成功,文件异常:{path_problem}",
                data=None
            )

        # Best-effort file removal; the outcome only affects the message.
        try:
            Path(abs_path_str).unlink()
        except Exception as unlink_error:
            extra_msg = f"(文件删除失败:{str(unlink_error)}"
        else:
            extra_msg = "(已删除文件)"

        # Drop the in-memory model if it was loaded from the deleted file.
        if _yolo_model and str(_yolo_model.model_path) == abs_path_str:
            _yolo_model = None
            print(f"[模型删除] 已清空全局模型(路径:{abs_path_str}")

        return APIResponse(
            code=200,
            message=f"模型删除成功ID{model_id} {extra_msg}",
            data=None
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# 7. Download a model file
@router.get("/{model_id}/download", summary="下载模型文件")
async def download_model(model_id: int):
    """Send the file of the given model as a binary download.

    Args:
        model_id: Primary key of the model record.

    Returns:
        ``FileResponse`` for the validated model file, with a download name
        built from the model id and name.

    Raises:
        HTTPException: 404 when the record does not exist; whatever
            ``get_valid_model_abs_path`` raises for an invalid file; 500 on a
            database error.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        record = cursor.fetchone()
        if not record:
            raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")

        file_path = Path(get_valid_model_abs_path(record["path"]))
        download_name = f"model_{model_id}_{record['name']}.pt"
        return FileResponse(
            path=file_path,
            filename=download_name,
            media_type="application/octet-stream"
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# Expose the current model to other modules
def get_current_yolo_model():
    """Return the loaded model, loading the default model on demand.

    If no model is loaded yet, the default model's path is read from the
    database, validated and loaded. Any failure is printed rather than
    raised.

    Returns:
        The loaded model, or ``None`` when no default model exists or loading
        failed.
    """
    global _yolo_model
    if _yolo_model:
        return _yolo_model

    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT path FROM model WHERE is_default = 1")
        record = cursor.fetchone()
        if not record:
            print("[get_current_yolo_model] 暂无默认模型")
            return None

        _yolo_model = load_yolo_model(get_valid_model_abs_path(record["path"]))
        print(
            "[get_current_yolo_model] 自动加载默认模型成功"
            if _yolo_model
            else "[get_current_yolo_model] 自动加载默认模型失败"
        )
    except Exception as e:
        # Best effort: callers get None instead of an exception.
        print(f"[get_current_yolo_model] 加载失败:{str(e)}")
    finally:
        db.close_connection(conn, cursor)
    return _yolo_model