目前可以成功动态更换模型运行的
This commit is contained in:
497
service/model_service.py
Normal file
497
service/model_service.py
Normal file
@ -0,0 +1,497 @@
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from mysql.connector import Error as MySQLError
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# 复用项目依赖
|
||||
from ds.db import db
|
||||
from schema.model_schema import (
|
||||
ModelCreateRequest,
|
||||
ModelUpdateRequest,
|
||||
ModelResponse,
|
||||
ModelListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
|
||||
|
||||
# 路径配置
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent
|
||||
MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models"
|
||||
MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True)
|
||||
DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
|
||||
|
||||
# 模型限制
|
||||
ALLOWED_MODEL_EXT = {"pt"}
|
||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
# 全局模型变量
|
||||
global _yolo_model
|
||||
_yolo_model = None
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 工具函数:验证模型路径
|
||||
def get_valid_model_abs_path(relative_path: str) -> str:
|
||||
try:
|
||||
relative_path = relative_path.replace("/", os.sep)
|
||||
model_abs_path = PROJECT_ROOT / relative_path
|
||||
model_abs_path = model_abs_path.resolve()
|
||||
model_abs_path_str = str(model_abs_path)
|
||||
|
||||
if not model_abs_path_str.startswith(str(MODEL_SAVE_ROOT)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
if not model_abs_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"模型文件不存在!路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
if not model_abs_path.is_file():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"路径不是文件!路径:{model_abs_path_str}"
|
||||
)
|
||||
|
||||
file_size = model_abs_path.stat().st_size
|
||||
if file_size > MAX_MODEL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型文件过大({file_size // 1024 // 1024}MB),超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
|
||||
)
|
||||
|
||||
file_ext = model_abs_path.suffix.lower()
|
||||
if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
|
||||
)
|
||||
|
||||
print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
|
||||
return model_abs_path_str
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"路径处理失败:{str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
# 1. 上传模型
|
||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||
async def upload_model(
|
||||
name: str = Form(..., description="模型名称"),
|
||||
description: str = Form(None, description="模型描述"),
|
||||
is_default: bool = Form(False, description="是否设为默认模型"),
|
||||
file: UploadFile = File(..., description=f"YOLO模型文件(.pt,最大{MAX_MODEL_SIZE // 1024 // 1024}MB)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
saved_file_path = None
|
||||
try:
|
||||
# 校验文件
|
||||
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||
if file_ext not in ALLOWED_MODEL_EXT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
|
||||
)
|
||||
if file.size > MAX_MODEL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB,当前{file.size // 1024 // 1024}MB"
|
||||
)
|
||||
|
||||
# 保存文件
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
safe_filename = f"model_{timestamp}_{file.filename.replace(' ', '_')}"
|
||||
saved_file_path = MODEL_SAVE_ROOT / safe_filename
|
||||
with open(saved_file_path, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
saved_file_path.chmod(0o644) # 设置权限
|
||||
|
||||
# 数据库路径处理
|
||||
db_relative_path = str(saved_file_path).replace(DB_PATH_PREFIX_TO_REMOVE, "").replace(os.sep, "/")
|
||||
|
||||
# 数据库操作
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
if is_default:
|
||||
cursor.execute("UPDATE model SET is_default = 0")
|
||||
|
||||
insert_sql = """
|
||||
INSERT INTO model (name, path, is_default, description, file_size)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file.size))
|
||||
conn.commit()
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
|
||||
new_model = cursor.fetchone()
|
||||
if not new_model:
|
||||
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||||
|
||||
# 加载默认模型
|
||||
global _yolo_model
|
||||
if is_default:
|
||||
valid_abs_path = get_valid_model_abs_path(db_relative_path)
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=201,
|
||||
message=f"模型上传成功!ID:{new_model['id']}",
|
||||
data=ModelResponse(**new_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
if saved_file_path and saved_file_path.exists():
|
||||
saved_file_path.unlink()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
except Exception as e:
|
||||
if saved_file_path and saved_file_path.exists():
|
||||
saved_file_path.unlink()
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 2. 获取模型列表
|
||||
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||||
async def get_model_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
is_default: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if is_default is not None:
|
||||
where_clause.append("is_default = %s")
|
||||
params.append(1 if is_default else 0)
|
||||
|
||||
# 总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM model"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页数据
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM model"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
model_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取成功!共{total}条记录",
|
||||
data=ModelListResponse(
|
||||
total=total,
|
||||
models=[ModelResponse(**model) for model in model_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 3. 获取默认模型
|
||||
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
|
||||
async def get_default_model():
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE is_default = 1")
|
||||
default_model = cursor.fetchone()
|
||||
|
||||
if not default_model:
|
||||
raise HTTPException(status_code=404, detail="暂无默认模型")
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
global _yolo_model
|
||||
|
||||
if not _yolo_model:
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="默认模型查询成功",
|
||||
data=ModelResponse(**default_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 4. 获取单个模型详情
|
||||
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
|
||||
async def get_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
model = cursor.fetchone()
|
||||
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
try:
|
||||
model_abs_path = get_valid_model_abs_path(model["path"])
|
||||
except HTTPException as e:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"查询成功,但路径异常:{e.detail}",
|
||||
data=ModelResponse(**model)
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=ModelResponse(**model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 5. 更新模型信息
|
||||
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
|
||||
async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
exist_model = cursor.fetchone()
|
||||
if not exist_model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
update_fields = []
|
||||
params = []
|
||||
if model_update.name is not None:
|
||||
update_fields.append("name = %s")
|
||||
params.append(model_update.name)
|
||||
if model_update.description is not None:
|
||||
update_fields.append("description = %s")
|
||||
params.append(model_update.description)
|
||||
|
||||
need_load_default = False
|
||||
if model_update.is_default is not None:
|
||||
if model_update.is_default:
|
||||
cursor.execute("UPDATE model SET is_default = 0")
|
||||
update_fields.append("is_default = 1")
|
||||
need_load_default = True
|
||||
else:
|
||||
cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
|
||||
default_count = cursor.fetchone()["cnt"]
|
||||
if default_count == 1 and exist_model["is_default"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="当前是唯一默认模型,不可取消!"
|
||||
)
|
||||
update_fields.append("is_default = 0")
|
||||
|
||||
if not update_fields:
|
||||
raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
|
||||
|
||||
params.append(model_id)
|
||||
update_sql = f"""
|
||||
UPDATE model
|
||||
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %s
|
||||
"""
|
||||
cursor.execute(update_sql, params)
|
||||
conn.commit()
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
updated_model = cursor.fetchone()
|
||||
|
||||
global _yolo_model
|
||||
if need_load_default:
|
||||
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="模型更新成功",
|
||||
data=ModelResponse(**updated_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 6. 删除模型
|
||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||
async def delete_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
exist_model = cursor.fetchone()
|
||||
if not exist_model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
if exist_model["is_default"]:
|
||||
raise HTTPException(status_code=400, detail="默认模型不可删除!")
|
||||
|
||||
try:
|
||||
model_abs_path_str = get_valid_model_abs_path(exist_model["path"])
|
||||
model_abs_path = Path(model_abs_path_str)
|
||||
except HTTPException as e:
|
||||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||
conn.commit()
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"记录删除成功,文件异常:{e.detail}",
|
||||
data=None
|
||||
)
|
||||
|
||||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||
conn.commit()
|
||||
|
||||
extra_msg = ""
|
||||
try:
|
||||
model_abs_path.unlink()
|
||||
extra_msg = f"(已删除文件)"
|
||||
except Exception as e:
|
||||
extra_msg = f"(文件删除失败:{str(e)})"
|
||||
|
||||
global _yolo_model
|
||||
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str:
|
||||
_yolo_model = None
|
||||
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型删除成功!ID:{model_id} {extra_msg}",
|
||||
data=None
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 7. 下载模型文件
|
||||
@router.get("/{model_id}/download", summary="下载模型文件")
|
||||
async def download_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
model = cursor.fetchone()
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!ID:{model_id}")
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(model["path"])
|
||||
model_abs_path = Path(valid_abs_path)
|
||||
|
||||
return FileResponse(
|
||||
path=model_abs_path,
|
||||
filename=f"model_{model_id}_{model['name']}.pt",
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 对外提供当前模型
|
||||
def get_current_yolo_model():
|
||||
"""供检测模块获取当前加载的模型"""
|
||||
global _yolo_model
|
||||
if not _yolo_model:
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("SELECT path FROM model WHERE is_default = 1")
|
||||
default_model = cursor.fetchone()
|
||||
if not default_model:
|
||||
print("[get_current_yolo_model] 暂无默认模型")
|
||||
return None
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
print(f"[get_current_yolo_model] 自动加载默认模型成功")
|
||||
else:
|
||||
print(f"[get_current_yolo_model] 自动加载默认模型失败")
|
||||
except Exception as e:
|
||||
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
return _yolo_model
|
Reference in New Issue
Block a user