可以成功动态更换yolo模型并重启服务生效

This commit is contained in:
2025-09-12 18:28:43 +08:00
parent 4be7f7bf14
commit 206652d6bb
6 changed files with 499 additions and 123 deletions

View File

@ -1,10 +1,13 @@
import subprocess
import os
import sys
import shutil
import threading
from pathlib import Path
from datetime import datetime
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
import os
import shutil
from pathlib import Path
from datetime import datetime
# 复用项目依赖
from ds.db import db
@ -15,7 +18,7 @@ from schema.model_schema import (
ModelListResponse
)
from schema.response_schema import APIResponse
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
from util.model_util import load_yolo_model # 模型加载工具
# 路径配置
CURRENT_FILE_PATH = Path(__file__).resolve()
@ -28,14 +31,63 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
# 全局模型变量
global _yolo_model
# 全局模型变量(带版本标识)
global _yolo_model, _current_model_version
_yolo_model = None
_current_model_version = None # 模型版本标识(用于检测模型是否变化)
router = APIRouter(prefix="/models", tags=["模型管理"])
# 工具函数:验证模型路径
# 服务重启核心工具函数
def restart_service():
"""重启当前FastAPI服务进程"""
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
try:
# 关闭所有WebSocket连接
try:
from ws import connected_clients
if connected_clients:
print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接")
for ip, conn in list(connected_clients.items()):
try:
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
conn.websocket.close(code=1001, reason="模型更新,服务重启")
connected_clients.pop(ip)
except Exception as e:
print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}")
except ImportError:
print("[服务重启] 未找到WebSocket连接管理模块跳过连接关闭")
# 关闭数据库连接
if hasattr(db, "close_all_connections"):
db.close_all_connections()
else:
print("[警告] db模块未实现close_all_connections可能存在连接泄漏")
# 启动新进程
python_exec = sys.executable
current_argv = sys.argv
print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}")
subprocess.Popen(
[python_exec] + current_argv,
close_fds=True,
start_new_session=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
# 退出当前进程
print("[服务重启] 新进程已启动,当前进程退出")
sys.exit(0)
except Exception as e:
print(f"[服务重启] 重启失败:{str(e)}")
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
# 模型路径验证工具函数
def get_valid_model_abs_path(relative_path: str) -> str:
try:
relative_path = relative_path.replace("/", os.sep)
@ -87,6 +139,49 @@ def get_valid_model_abs_path(relative_path: str) -> str:
) from e
# 对外提供当前模型(带版本校验)
def get_current_yolo_model():
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
global _yolo_model, _current_model_version
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT path FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
print("[get_current_yolo_model] 暂无默认模型")
return None
# 1. 计算当前默认模型的唯一版本标识
# (路径哈希 + 文件修改时间戳,确保模型变化时版本变化)
valid_abs_path = get_valid_model_abs_path(default_model["path"])
model_stat = os.stat(valid_abs_path)
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
# 2. 版本未变化则复用已有模型(核心优化点)
if _yolo_model and _current_model_version == model_version:
# print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...")
return _yolo_model
# 3. 版本变化时重新加载模型
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
_current_model_version = model_version # 更新版本标识
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...")
else:
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
return _yolo_model
except Exception as e:
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
return None
finally:
db.close_connection(conn, cursor)
# 1. 上传模型
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
async def upload_model(
@ -142,12 +237,16 @@ async def upload_model(
if not new_model:
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
# 加载默认模型
global _yolo_model
# 加载默认模型并更新版本
global _yolo_model, _current_model_version
if is_default:
valid_abs_path = get_valid_model_abs_path(db_relative_path)
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException(
status_code=500,
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path}"
@ -246,11 +345,15 @@ async def get_default_model():
raise HTTPException(status_code=404, detail="暂无默认模型")
valid_abs_path = get_valid_model_abs_path(default_model["path"])
global _yolo_model
global _yolo_model, _current_model_version
if not _yolo_model:
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException(
status_code=500,
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path}"
@ -358,11 +461,16 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone()
global _yolo_model
# 更新模型后重置版本标识
global _yolo_model, _current_model_version
if need_load_default:
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
_yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model:
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException(
status_code=500,
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path}"
@ -382,6 +490,96 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
db.close_connection(conn, cursor)
# 5.1 更换默认模型(自动重启服务)
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
async def set_default_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
conn.autocommit = False # 开启事务
# 1. 校验目标模型是否存在
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
target_model = cursor.fetchone()
if not target_model:
raise HTTPException(status_code=404, detail=f"目标模型不存在ID{model_id}")
# 2. 检查是否已为默认模型
if target_model["is_default"]:
return APIResponse(
code=200,
message=f"模型ID{model_id} 已是默认模型,无需更换和重启",
data=ModelResponse(**target_model)
)
# 3. 校验目标模型文件合法性
try:
valid_abs_path = get_valid_model_abs_path(target_model["path"])
except HTTPException as e:
raise HTTPException(
status_code=400,
detail=f"目标模型文件非法,无法设为默认:{e.detail}"
) from e
# 4. 数据库事务:更新默认模型状态
try:
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
cursor.execute(
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
(model_id,)
)
conn.commit()
except MySQLError as e:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
) from e
# 5. 验证新模型可加载性
test_model = load_yolo_model(valid_abs_path)
if not test_model:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path}"
)
# 6. 重新查询更新后的模型信息
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone()
# 7. 重置版本标识(关键:确保下次检测加载新模型)
global _current_model_version
_current_model_version = None
print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型")
# 8. 延迟重启服务
print(f"[更换默认模型] 成功将在1秒后重启服务以应用新模型ID{model_id}")
threading.Timer(
interval=1.0,
function=restart_service
).start()
# 9. 返回成功响应
return APIResponse(
code=200,
message=f"已成功更换默认模型ID{model_id}服务将在1秒后自动重启以应用新模型",
data=ModelResponse(**updated_model)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
if conn:
conn.autocommit = True
db.close_connection(conn, cursor)
# 6. 删除模型
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
async def delete_model(model_id: int):
@ -420,10 +618,12 @@ async def delete_model(model_id: int):
except Exception as e:
extra_msg = f"(文件删除失败:{str(e)}"
global _yolo_model
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str:
# 如果删除的是当前加载的模型,重置缓存
global _yolo_model, _current_model_version
if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str:
_yolo_model = None
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str}")
_current_model_version = None
print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str}")
return APIResponse(
code=200,
@ -466,32 +666,3 @@ async def download_model(model_id: int):
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 对外提供当前模型
def get_current_yolo_model():
"""供检测模块获取当前加载的模型"""
global _yolo_model
if not _yolo_model:
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT path FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
print("[get_current_yolo_model] 暂无默认模型")
return None
valid_abs_path = get_valid_model_abs_path(default_model["path"])
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
print(f"[get_current_yolo_model] 自动加载默认模型成功")
else:
print(f"[get_current_yolo_model] 自动加载默认模型失败")
except Exception as e:
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
finally:
db.close_connection(conn, cursor)
return _yolo_model