可以成功动态更换yolo模型并重启服务生效
This commit is contained in:
@ -1,10 +1,13 @@
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from mysql.connector import Error as MySQLError
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# 复用项目依赖
|
||||
from ds.db import db
|
||||
@ -15,7 +18,7 @@ from schema.model_schema import (
|
||||
ModelListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
|
||||
from util.model_util import load_yolo_model # 模型加载工具
|
||||
|
||||
# 路径配置
|
||||
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||
@ -28,14 +31,63 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
|
||||
ALLOWED_MODEL_EXT = {"pt"}
|
||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
# 全局模型变量
|
||||
global _yolo_model
|
||||
# 全局模型变量(带版本标识)
|
||||
global _yolo_model, _current_model_version
|
||||
_yolo_model = None
|
||||
_current_model_version = None # 模型版本标识(用于检测模型是否变化)
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 工具函数:验证模型路径
|
||||
# 服务重启核心工具函数
|
||||
def restart_service():
|
||||
"""重启当前FastAPI服务进程"""
|
||||
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
|
||||
try:
|
||||
# 关闭所有WebSocket连接
|
||||
try:
|
||||
from ws import connected_clients
|
||||
if connected_clients:
|
||||
print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接")
|
||||
for ip, conn in list(connected_clients.items()):
|
||||
try:
|
||||
if conn.consumer_task and not conn.consumer_task.done():
|
||||
conn.consumer_task.cancel()
|
||||
conn.websocket.close(code=1001, reason="模型更新,服务重启")
|
||||
connected_clients.pop(ip)
|
||||
except Exception as e:
|
||||
print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}")
|
||||
except ImportError:
|
||||
print("[服务重启] 未找到WebSocket连接管理模块,跳过连接关闭")
|
||||
|
||||
# 关闭数据库连接
|
||||
if hasattr(db, "close_all_connections"):
|
||||
db.close_all_connections()
|
||||
else:
|
||||
print("[警告] db模块未实现close_all_connections,可能存在连接泄漏")
|
||||
|
||||
# 启动新进程
|
||||
python_exec = sys.executable
|
||||
current_argv = sys.argv
|
||||
print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}")
|
||||
subprocess.Popen(
|
||||
[python_exec] + current_argv,
|
||||
close_fds=True,
|
||||
start_new_session=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE
|
||||
)
|
||||
|
||||
# 退出当前进程
|
||||
print("[服务重启] 新进程已启动,当前进程退出")
|
||||
sys.exit(0)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[服务重启] 重启失败:{str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
|
||||
|
||||
|
||||
# 模型路径验证工具函数
|
||||
def get_valid_model_abs_path(relative_path: str) -> str:
|
||||
try:
|
||||
relative_path = relative_path.replace("/", os.sep)
|
||||
@ -87,6 +139,49 @@ def get_valid_model_abs_path(relative_path: str) -> str:
|
||||
) from e
|
||||
|
||||
|
||||
# 对外提供当前模型(带版本校验)
|
||||
def get_current_yolo_model():
|
||||
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
|
||||
global _yolo_model, _current_model_version
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("SELECT path FROM model WHERE is_default = 1")
|
||||
default_model = cursor.fetchone()
|
||||
if not default_model:
|
||||
print("[get_current_yolo_model] 暂无默认模型")
|
||||
return None
|
||||
|
||||
# 1. 计算当前默认模型的唯一版本标识
|
||||
# (路径哈希 + 文件修改时间戳,确保模型变化时版本变化)
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
|
||||
# 2. 版本未变化则复用已有模型(核心优化点)
|
||||
if _yolo_model and _current_model_version == model_version:
|
||||
# print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...)")
|
||||
return _yolo_model
|
||||
|
||||
# 3. 版本变化时重新加载模型
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
_current_model_version = model_version # 更新版本标识
|
||||
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)")
|
||||
else:
|
||||
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
|
||||
return _yolo_model
|
||||
|
||||
except Exception as e:
|
||||
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
|
||||
return None
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 1. 上传模型
|
||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||
async def upload_model(
|
||||
@ -142,12 +237,16 @@ async def upload_model(
|
||||
if not new_model:
|
||||
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||||
|
||||
# 加载默认模型
|
||||
global _yolo_model
|
||||
# 加载默认模型并更新版本
|
||||
global _yolo_model, _current_model_version
|
||||
if is_default:
|
||||
valid_abs_path = get_valid_model_abs_path(db_relative_path)
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})"
|
||||
@ -246,11 +345,15 @@ async def get_default_model():
|
||||
raise HTTPException(status_code=404, detail="暂无默认模型")
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
global _yolo_model
|
||||
global _yolo_model, _current_model_version
|
||||
|
||||
if not _yolo_model:
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})"
|
||||
@ -358,11 +461,16 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
updated_model = cursor.fetchone()
|
||||
|
||||
global _yolo_model
|
||||
# 更新模型后重置版本标识
|
||||
global _yolo_model, _current_model_version
|
||||
if need_load_default:
|
||||
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if not _yolo_model:
|
||||
if _yolo_model:
|
||||
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||
model_stat = os.stat(valid_abs_path)
|
||||
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})"
|
||||
@ -382,6 +490,96 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 5.1 更换默认模型(自动重启服务)
|
||||
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
|
||||
async def set_default_model(model_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
conn.autocommit = False # 开启事务
|
||||
|
||||
# 1. 校验目标模型是否存在
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
target_model = cursor.fetchone()
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail=f"目标模型不存在!ID:{model_id}")
|
||||
|
||||
# 2. 检查是否已为默认模型
|
||||
if target_model["is_default"]:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型ID:{model_id} 已是默认模型,无需更换和重启",
|
||||
data=ModelResponse(**target_model)
|
||||
)
|
||||
|
||||
# 3. 校验目标模型文件合法性
|
||||
try:
|
||||
valid_abs_path = get_valid_model_abs_path(target_model["path"])
|
||||
except HTTPException as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"目标模型文件非法,无法设为默认:{e.detail}"
|
||||
) from e
|
||||
|
||||
# 4. 数据库事务:更新默认模型状态
|
||||
try:
|
||||
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
|
||||
cursor.execute(
|
||||
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
|
||||
(model_id,)
|
||||
)
|
||||
conn.commit()
|
||||
except MySQLError as e:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
|
||||
) from e
|
||||
|
||||
# 5. 验证新模型可加载性
|
||||
test_model = load_yolo_model(valid_abs_path)
|
||||
if not test_model:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path})"
|
||||
)
|
||||
|
||||
# 6. 重新查询更新后的模型信息
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
updated_model = cursor.fetchone()
|
||||
|
||||
# 7. 重置版本标识(关键:确保下次检测加载新模型)
|
||||
global _current_model_version
|
||||
_current_model_version = None
|
||||
print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型")
|
||||
|
||||
# 8. 延迟重启服务
|
||||
print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})")
|
||||
threading.Timer(
|
||||
interval=1.0,
|
||||
function=restart_service
|
||||
).start()
|
||||
|
||||
# 9. 返回成功响应
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"已成功更换默认模型(ID:{model_id})!服务将在1秒后自动重启以应用新模型",
|
||||
data=ModelResponse(**updated_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
if conn:
|
||||
conn.autocommit = True
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 6. 删除模型
|
||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||
async def delete_model(model_id: int):
|
||||
@ -420,10 +618,12 @@ async def delete_model(model_id: int):
|
||||
except Exception as e:
|
||||
extra_msg = f"(文件删除失败:{str(e)})"
|
||||
|
||||
global _yolo_model
|
||||
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str:
|
||||
# 如果删除的是当前加载的模型,重置缓存
|
||||
global _yolo_model, _current_model_version
|
||||
if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str:
|
||||
_yolo_model = None
|
||||
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})")
|
||||
_current_model_version = None
|
||||
print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str})")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
@ -466,32 +666,3 @@ async def download_model(model_id: int):
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 对外提供当前模型
|
||||
def get_current_yolo_model():
|
||||
"""供检测模块获取当前加载的模型"""
|
||||
global _yolo_model
|
||||
if not _yolo_model:
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("SELECT path FROM model WHERE is_default = 1")
|
||||
default_model = cursor.fetchone()
|
||||
if not default_model:
|
||||
print("[get_current_yolo_model] 暂无默认模型")
|
||||
return None
|
||||
|
||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||
_yolo_model = load_yolo_model(valid_abs_path)
|
||||
if _yolo_model:
|
||||
print(f"[get_current_yolo_model] 自动加载默认模型成功")
|
||||
else:
|
||||
print(f"[get_current_yolo_model] 自动加载默认模型失败")
|
||||
except Exception as e:
|
||||
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
return _yolo_model
|
||||
|
Reference in New Issue
Block a user