Files
video/service/model_service.py
2025-09-15 18:55:21 +08:00

686 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import subprocess
import os
import sys
import shutil
import threading
from pathlib import Path
from datetime import datetime
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError
# 复用项目依赖
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.model_schema import (
ModelCreateRequest,
ModelUpdateRequest,
ModelResponse,
ModelListResponse
)
from schema.response_schema import APIResponse
from util.model_util import load_yolo_model # 模型加载工具
# Path configuration
CURRENT_FILE_PATH = Path(__file__).resolve()
# Directory two levels above this file.
PROJECT_ROOT = CURRENT_FILE_PATH.parent.parent
# Model files must live under this directory (enforced by get_valid_model_abs_path).
MODEL_SAVE_ROOT = PROJECT_ROOT / "resource" / "models"
MODEL_SAVE_ROOT.mkdir(exist_ok=True, parents=True)
# Absolute prefix stripped from saved file paths so the DB stores project-relative paths.
DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
# Upload constraints
ALLOWED_MODEL_EXT = {"pt"}  # extensions without the leading dot
MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100 MB
# Process-wide model cache (model object, version key and confidence threshold).
# A `global` statement at module scope is a no-op, so the names are simply bound
# here; functions that rebind them declare `global` themselves.
_yolo_model = None
_current_model_version = None  # "<hash(abs path)>_<mtime_ns>" of the cached model file
_current_conf_threshold = 0.8  # initial detection confidence threshold
router = APIRouter(prefix="/models", tags=["模型管理"])
# Core helper: restart the service process
def restart_service():
    """Restart the running FastAPI service process.

    Performs best-effort cleanup (WebSocket clients, DB connections), spawns a
    new interpreter with the same ``sys.argv`` and then terminates this process.

    This function is scheduled by ``set_default_model`` through
    ``threading.Timer``, i.e. it runs on a worker thread. ``sys.exit()`` there
    only raises SystemExit inside that thread and leaves the old server running
    (still holding its listening port), so the process is ended with
    ``os._exit`` instead.

    Raises:
        HTTPException: 500 if cleanup or spawning the new process fails. When
            invoked from a Timer thread nothing catches it; it is only reported.
    """
    print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
    try:
        # Close all WebSocket connections (best effort).
        try:
            from ws import connected_clients
            if connected_clients:
                print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接")
                for ip, conn in list(connected_clients.items()):
                    try:
                        if conn.consumer_task and not conn.consumer_task.done():
                            conn.consumer_task.cancel()
                        # NOTE(review): if `websocket` is a Starlette WebSocket, close()
                        # is a coroutine; calling it from this non-event-loop thread
                        # without awaiting it does not actually close the socket, and
                        # Task.cancel() is not thread-safe either. The process exit
                        # below drops the connections regardless.
                        conn.websocket.close(code=1001, reason="模型更新,服务重启")
                        connected_clients.pop(ip)
                    except Exception as e:
                        print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}")
        except ImportError:
            print("[服务重启] 未找到WebSocket连接管理模块跳过连接关闭")
        # Close database connections.
        if hasattr(db, "close_all_connections"):
            db.close_all_connections()
        else:
            print("[警告] db模块未实现close_all_connections可能存在连接泄漏")
        # Spawn the replacement process.
        python_exec = sys.executable
        current_argv = sys.argv
        print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}")
        # stdout/stderr are inherited rather than piped. A PIPE whose read end
        # belongs to this (about-to-exit) process would make the child block once
        # the pipe buffer fills, then fail with SIGPIPE / BrokenPipeError.
        subprocess.Popen(
            [python_exec] + current_argv,
            close_fds=True,
            start_new_session=True,
        )
        print("[服务重启] 新进程已启动,当前进程退出")
    except Exception as e:
        print(f"[服务重启] 重启失败:{str(e)}")
        raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
    # os._exit skips interpreter cleanup (including stdio buffer flushing), so
    # flush first. Flushing is best-effort and must never prevent the exit.
    for stream in (sys.stdout, sys.stderr):
        try:
            if stream is not None:
                stream.flush()
        except Exception:
            pass
    # Terminate the whole process from any thread so the port is released for
    # the new process.
    os._exit(0)
# Helper: validate a stored model path
def get_valid_model_abs_path(relative_path: str) -> str:
    """Resolve a DB-stored model path and validate it.

    Args:
        relative_path: Path relative to PROJECT_ROOT using "/" separators (the
            format written by upload_model).

    Returns:
        The resolved absolute path as a string.

    Raises:
        HTTPException: 400 if the path is outside MODEL_SAVE_ROOT, is not a
            regular file, exceeds MAX_MODEL_SIZE or has a disallowed extension;
            404 if it does not exist; 500 on any other error.
    """
    try:
        relative_path = relative_path.replace("/", os.sep)
        model_abs_path = (PROJECT_ROOT / relative_path).resolve()
        model_abs_path_str = str(model_abs_path)
        # Containment is checked on path components, not on the string prefix:
        # startswith() would also accept siblings such as ".../models_backup/x.pt".
        # The root is resolved too, so a symlinked models directory still matches
        # the resolved file path.
        allowed_root = MODEL_SAVE_ROOT.resolve()
        if allowed_root not in model_abs_path.parents:
            raise HTTPException(
                status_code=400,
                detail=f"模型路径非法!允许目录:{str(MODEL_SAVE_ROOT)},当前路径:{model_abs_path_str}"
            )
        if not model_abs_path.exists():
            raise HTTPException(
                status_code=404,
                detail=f"模型文件不存在!路径:{model_abs_path_str}"
            )
        if not model_abs_path.is_file():
            raise HTTPException(
                status_code=400,
                detail=f"路径不是文件!路径:{model_abs_path_str}"
            )
        file_size = model_abs_path.stat().st_size
        if file_size > MAX_MODEL_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"模型文件过大({file_size // 1024 // 1024}MB超过限制{MAX_MODEL_SIZE // 1024 // 1024}MB"
            )
        file_ext = model_abs_path.suffix.lower()
        if file_ext not in [f".{ext}" for ext in ALLOWED_MODEL_EXT]:
            raise HTTPException(
                status_code=400,
                detail=f"模型格式非法(仅支持{ALLOWED_MODEL_EXT})!当前格式:{file_ext}"
            )
        print(f"[模型路径校验] 成功!路径:{model_abs_path_str},大小:{file_size // 1024}KB")
        return model_abs_path_str
    except HTTPException:
        # Validation errors keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"路径处理失败:{str(e)}"
        ) from e
# Expose the current default model to other modules (reloads only when its version changes)
def get_current_yolo_model():
    """Return the YOLO model for the DB row marked as default, or None.

    The default row is looked up on every call. The cached model is reused while
    its version key (hash of the resolved path plus the file's mtime in ns) is
    unchanged; otherwise the file is reloaded through ``load_yolo_model``. Any
    error is printed and yields ``None``.
    """
    global _yolo_model, _current_model_version
    conn = cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT path FROM model WHERE is_default = 1")
        row = cursor.fetchone()
        if not row:
            print("[get_current_yolo_model] 暂无默认模型")
            return None

        # The version key identifies both the file and its on-disk revision.
        model_path = get_valid_model_abs_path(row["path"])
        version_key = f"{hash(model_path)}_{os.stat(model_path).st_mtime_ns}"

        # Cache hit: same file, unchanged since it was loaded.
        if _yolo_model and _current_model_version == version_key:
            return _yolo_model

        # Cache miss: reload, and record the new version only on success.
        _yolo_model = load_yolo_model(model_path)
        if not _yolo_model:
            print(f"[get_current_yolo_model] 加载最新默认模型失败:{model_path}")
        else:
            setattr(_yolo_model, "model_path", model_path)
            _current_model_version = version_key
            print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{version_key[:10]}...")
        return _yolo_model
    except Exception as exc:
        print(f"[get_current_yolo_model] 加载失败:{str(exc)}")
        return None
    finally:
        db.close_connection(conn, cursor)
# Accessor for the current confidence threshold
def get_current_conf_threshold():
    """Return the detection confidence threshold currently held by this process.

    The module-level value is only read here, so no ``global`` declaration is needed.
    """
    threshold = _current_conf_threshold
    return threshold
# 1. Upload a model
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
@encrypt_response()
async def upload_model(
    name: str = Form(..., description="模型名称"),
    description: str = Form(None, description="模型描述"),
    is_default: bool = Form(False, description="是否设为默认模型"),
    file: UploadFile = File(..., description=f"YOLO模型文件.pt最大{MAX_MODEL_SIZE // 1024 // 1024}MB")
):
    """Store an uploaded YOLO ``.pt`` file under MODEL_SAVE_ROOT and insert a
    ``model`` row for it, optionally making it the default model.

    Ordering guarantees:
      * Validation errors keep their own status code (e.g. 400) instead of being
        re-wrapped as 500 by the generic handler.
      * With ``is_default`` the file is validated and loaded *before* any DB
        write, so a model that cannot be loaded never becomes the default.
      * The saved file is removed on failure only while no DB row referencing
        it has been committed; deleting it afterwards would leave a dangling path.

    Raises:
        HTTPException: 400 for a bad extension/size or invalid path, 500 for
            load, DB or other errors.
    """
    global _yolo_model, _current_model_version
    conn = None
    cursor = None
    saved_file_path = None
    committed = False

    def _discard_uncommitted_file():
        # Best-effort removal of a saved file that no committed DB row points to.
        if committed or saved_file_path is None:
            return
        try:
            if saved_file_path.exists():
                saved_file_path.unlink()
        except OSError as cleanup_err:
            print(f"[模型上传] 清理临时文件失败:{str(cleanup_err)}")

    try:
        # Validate the upload. Only the last component of the client-supplied
        # name is used, so names like "../../x.pt" cannot escape MODEL_SAVE_ROOT.
        client_name = Path((file.filename or "").replace("\\", "/")).name
        file_ext = client_name.split(".")[-1].lower() if "." in client_name else ""
        if file_ext not in ALLOWED_MODEL_EXT:
            raise HTTPException(
                status_code=400,
                detail=f"仅支持{ALLOWED_MODEL_EXT}格式,当前:{file_ext}"
            )
        # UploadFile.size may be None; the real size is re-checked after saving.
        if file.size is not None and file.size > MAX_MODEL_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB当前{file.size // 1024 // 1024}MB"
            )
        # Save the file.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        safe_filename = f"model_{timestamp}_{client_name.replace(' ', '_')}"
        saved_file_path = MODEL_SAVE_ROOT / safe_filename
        # NOTE(review): two uploads with the same file name in the same second map
        # to the same path; the second one overwrites the first.
        with open(saved_file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        saved_file_path.chmod(0o644)
        file_size = saved_file_path.stat().st_size
        if file_size > MAX_MODEL_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB当前{file_size // 1024 // 1024}MB"
            )
        # Project-relative, "/"-separated path, as stored in the DB.
        db_relative_path = saved_file_path.relative_to(PROJECT_ROOT).as_posix()

        # A would-be default model must validate and load before the DB changes.
        new_default_model = None
        valid_abs_path = None
        if is_default:
            valid_abs_path = get_valid_model_abs_path(db_relative_path)
            new_default_model = load_yolo_model(valid_abs_path)
            if not new_default_model:
                raise HTTPException(
                    status_code=500,
                    detail=f"加载默认模型失败,上传已取消(路径:{valid_abs_path})"
                )

        # Database writes.
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        if is_default:
            cursor.execute("UPDATE model SET is_default = 0")
        insert_sql = """
        INSERT INTO model (name, path, is_default, description, file_size)
        VALUES (%s, %s, %s, %s, %s)
        """
        cursor.execute(insert_sql, (name, db_relative_path, 1 if is_default else 0, description, file_size))
        conn.commit()
        committed = True
        cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
        new_model = cursor.fetchone()
        if not new_model:
            raise HTTPException(status_code=500, detail="上传成功但无法获取记录")

        # Publish the new default model to the cache only after the commit.
        if new_default_model:
            setattr(new_default_model, "model_path", valid_abs_path)
            _yolo_model = new_default_model
            _current_model_version = f"{hash(valid_abs_path)}_{os.stat(valid_abs_path).st_mtime_ns}"
        return APIResponse(
            code=200,
            message=f"模型上传成功ID{new_model['id']}",
            data=ModelResponse(**new_model)
        )
    except HTTPException:
        _discard_uncommitted_file()
        raise
    except MySQLError as e:
        if conn:
            conn.rollback()
        _discard_uncommitted_file()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    except Exception as e:
        _discard_uncommitted_file()
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
    finally:
        await file.close()
        db.close_connection(conn, cursor)
# 2. List models (paginated)
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
@encrypt_response()
async def get_model_list(
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
    name: str = Query(None),
    is_default: bool = Query(None)
):
    """Return one page of model rows, most recently updated first.

    Args:
        page: 1-based page number.
        page_size: Rows per page (1-100).
        name: Optional substring filter on ``name`` (SQL LIKE).
        is_default: Optional filter on the ``is_default`` flag.

    Raises:
        HTTPException: 500 on a MySQL error.
    """
    conn = cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)

        # Shared WHERE clause for the count and page queries.
        conditions, filter_params = [], []
        if name:
            conditions.append("name LIKE %s")
            filter_params.append(f"%{name}%")
        if is_default is not None:
            conditions.append("is_default = %s")
            filter_params.append(1 if is_default else 0)
        where_sql = (" WHERE " + " AND ".join(conditions)) if conditions else ""

        # Total number of matching rows.
        cursor.execute("SELECT COUNT(*) AS total FROM model" + where_sql, filter_params)
        total = cursor.fetchone()["total"]

        # Rows of the requested page.
        page_sql = "SELECT * FROM model" + where_sql + " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
        cursor.execute(page_sql, filter_params + [page_size, (page - 1) * page_size])
        rows = cursor.fetchall()

        return APIResponse(
            code=200,
            message=f"获取成功!共{total}条记录",
            data=ModelListResponse(
                total=total,
                models=[ModelResponse(**row) for row in rows]
            )
        )
    except MySQLError as err:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(err)}") from err
    finally:
        db.close_connection(conn, cursor)
# 3. Get the current default model
@router.get("/default", response_model=APIResponse, summary="获取当前默认模型")
@encrypt_response()
async def get_default_model():
    """Return the DB row of the default model.

    The row's file path is validated. When no model is cached in this process
    yet, the file is also loaded into the module-level cache. An already cached
    model is left as-is.

    Raises:
        HTTPException: 404 if no default model exists; errors from path
            validation; 500 if the model cannot be loaded or on a MySQL error.
    """
    global _yolo_model, _current_model_version
    conn = cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE is_default = 1")
        row = cursor.fetchone()
        if not row:
            raise HTTPException(status_code=404, detail="暂无默认模型")

        model_path = get_valid_model_abs_path(row["path"])

        # Warm the cache only when it is empty.
        if not _yolo_model:
            loaded = load_yolo_model(model_path)
            _yolo_model = loaded
            if not loaded:
                raise HTTPException(
                    status_code=500,
                    detail=f"默认模型存在,但加载失败(路径:{model_path}"
                )
            setattr(loaded, "model_path", model_path)
            _current_model_version = f"{hash(model_path)}_{os.stat(model_path).st_mtime_ns}"

        return APIResponse(
            code=200,
            message="默认模型查询成功",
            data=ModelResponse(**row)
        )
    except MySQLError as err:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(err)}") from err
    finally:
        db.close_connection(conn, cursor)
# 4. Get a single model's details
@router.get("/{model_id}", response_model=APIResponse, summary="获取单个模型详情")
@encrypt_response()
async def get_model(model_id: int):
    """Return the DB row for ``model_id``.

    The stored path is also validated. A validation failure does not fail the
    request; it is reported in the response message instead.

    Raises:
        HTTPException: 404 if the row does not exist, 500 on a MySQL error.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        model = cursor.fetchone()
        if not model:
            raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
        message = "查询成功"
        try:
            # Called only for its validation; the resolved path itself is unused.
            get_valid_model_abs_path(model["path"])
        except HTTPException as e:
            message = f"查询成功,但路径异常:{e.detail}"
        return APIResponse(
            code=200,
            message=message,
            data=ModelResponse(**model)
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# 5. Update model metadata
@router.put("/{model_id}", response_model=APIResponse, summary="更新模型信息")
@encrypt_response()
async def update_model(model_id: int, model_update: ModelUpdateRequest):
    """Update ``name``, ``description`` and/or ``is_default`` of one model row.

    When ``is_default`` is set to True, the model file is validated and loaded
    *before* any DB write. Previously this happened after ``commit()``, so an
    invalid or unloadable file was already recorded as the default (and the
    cached model reset to None) by the time the error was raised.

    Raises:
        HTTPException: 404 if the model does not exist; 400 if no field is given
            or the only default model would be unset; errors from path
            validation; 500 if the new default cannot be loaded or on a MySQL error.
    """
    global _yolo_model, _current_model_version
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        exist_model = cursor.fetchone()
        if not exist_model:
            raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
        update_fields = []
        params = []
        if model_update.name is not None:
            update_fields.append("name = %s")
            params.append(model_update.name)
        if model_update.description is not None:
            update_fields.append("description = %s")
            params.append(model_update.description)
        new_default_model = None
        valid_abs_path = None
        if model_update.is_default is not None:
            if model_update.is_default:
                # The path column is not updated here, so the existing path is the
                # one that must validate and load before anything is written.
                valid_abs_path = get_valid_model_abs_path(exist_model["path"])
                new_default_model = load_yolo_model(valid_abs_path)
                if not new_default_model:
                    raise HTTPException(
                        status_code=500,
                        detail=f"加载新默认模型失败,未更新(路径:{valid_abs_path})"
                    )
                # NOTE(review): this reset and the row UPDATE below are separate
                # statements; they are atomic only if the connection is not in
                # autocommit mode — confirm the pool's autocommit setting.
                cursor.execute("UPDATE model SET is_default = 0")
                update_fields.append("is_default = 1")
            else:
                cursor.execute("SELECT COUNT(*) AS cnt FROM model WHERE is_default = 1")
                default_count = cursor.fetchone()["cnt"]
                if default_count == 1 and exist_model["is_default"]:
                    raise HTTPException(
                        status_code=400,
                        detail="当前是唯一默认模型,不可取消!"
                    )
                update_fields.append("is_default = 0")
        if not update_fields:
            raise HTTPException(status_code=400, detail="至少需提供一个更新字段")
        params.append(model_id)
        # update_fields holds only fixed column assignments; values are parameterized.
        update_sql = f"""
        UPDATE model
        SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
        WHERE id = %s
        """
        cursor.execute(update_sql, params)
        conn.commit()
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        updated_model = cursor.fetchone()
        # Swap the cached model only after the DB change is committed.
        if new_default_model:
            setattr(new_default_model, "model_path", valid_abs_path)
            _yolo_model = new_default_model
            _current_model_version = f"{hash(valid_abs_path)}_{os.stat(valid_abs_path).st_mtime_ns}"
        return APIResponse(
            code=200,
            message="模型更新成功",
            data=ModelResponse(**updated_model)
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# 5.1 Switch the default model (with a confidence-threshold parameter)
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
@encrypt_response()
async def set_default_model(
    model_id: int,
    conf_threshold: float = Query(0.8, ge=0.01, le=0.99, description="模型检测置信度阈值0.01-0.99")
):
    """Make ``model_id`` the default model and schedule a service restart.

    The target file is validated and test-loaded *before* the DB transaction.
    Previously the load test ran after ``conn.commit()``, so the rollback on a
    failed load was a no-op and the unloadable model stayed the default.

    Args:
        model_id: ID of the model to promote.
        conf_threshold: Detection confidence threshold stored in this process.

    Raises:
        HTTPException: 404 if the model does not exist; 400 if its file is
            invalid; 500 if it cannot be loaded or on a MySQL error.
    """
    global _current_model_version, _current_conf_threshold
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Start a transaction. NOTE(review): confirm that assigning `autocommit`
        # on the object returned by db.get_connection() reaches the real
        # connection (a pooled wrapper may not forward attribute writes).
        conn.autocommit = False
        # 1. The target model must exist.
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        target_model = cursor.fetchone()
        if not target_model:
            raise HTTPException(status_code=404, detail=f"目标模型不存在ID{model_id}")
        # 2. Nothing to do if it is already the default.
        #    NOTE(review): conf_threshold is not applied on this path.
        if target_model["is_default"]:
            return APIResponse(
                code=200,
                message=f"模型ID{model_id} 已是默认模型,无需更换和重启",
                data=ModelResponse(**target_model)
            )
        # 3. The target file must pass path validation.
        try:
            valid_abs_path = get_valid_model_abs_path(target_model["path"])
        except HTTPException as e:
            raise HTTPException(
                status_code=400,
                detail=f"目标模型文件非法,无法设为默认:{e.detail}"
            ) from e
        # 4. The target file must load. This is checked before any DB write, so a
        #    failure leaves the current default untouched.
        if not load_yolo_model(valid_abs_path):
            raise HTTPException(
                status_code=500,
                detail=f"新默认模型加载失败,未更改默认模型(路径:{valid_abs_path})"
            )
        # 5. Flip the default flag inside the transaction.
        try:
            cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
            cursor.execute(
                "UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
                (model_id,)
            )
            conn.commit()
        except MySQLError as e:
            conn.rollback()
            raise HTTPException(
                status_code=500,
                detail=f"更新默认模型状态失败(已回滚):{str(e)}"
            ) from e
        # 6. Re-read the updated row.
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        updated_model = cursor.fetchone()
        # 7. Invalidate the cached version so the next lookup reloads, and store
        #    the threshold. NOTE(review): restart_service launches a fresh process
        #    whose module-level _current_conf_threshold starts at its 0.8 default,
        #    so this value does not survive the restart scheduled below.
        _current_model_version = None
        _current_conf_threshold = conf_threshold
        print(f"[更换默认模型] 已重置模型版本标识,设置新置信度:{conf_threshold}")
        # 8. Restart after a short delay so this response can be returned first.
        print(f"[更换默认模型] 成功将在1秒后重启服务以应用新模型ID{model_id}")
        threading.Timer(
            interval=1.0,
            function=restart_service
        ).start()
        # 9. Success response.
        return APIResponse(
            code=200,
            message=f"已成功更换默认模型ID{model_id}),置信度:{conf_threshold}服务将在1秒后自动重启以应用新模型",
            data=ModelResponse(**updated_model)
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
    finally:
        if conn:
            conn.autocommit = True
        db.close_connection(conn, cursor)
# 6. Delete a model
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
@encrypt_response()
async def delete_model(model_id: int):
    """Delete a non-default model's DB row and, if its path validates, its file.

    If the stored path fails validation, only the row is deleted; the file is
    not touched and the validation message is returned. If the deleted file is
    the one cached in this process, the cache is cleared.

    Raises:
        HTTPException: 404 if the model does not exist, 400 for the default
            model, 500 on a MySQL error.
    """
    global _yolo_model, _current_model_version
    conn = cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        record = cursor.fetchone()
        if not record:
            raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
        if record["is_default"]:
            raise HTTPException(status_code=400, detail="默认模型不可删除!")

        try:
            file_path_str = get_valid_model_abs_path(record["path"])
            file_path = Path(file_path_str)
        except HTTPException as path_err:
            # Invalid or missing file: remove the row only.
            cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
            conn.commit()
            return APIResponse(
                code=200,
                message=f"记录删除成功,文件异常:{path_err.detail}",
                data=None
            )

        cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
        conn.commit()

        # File removal is best-effort; the outcome is reported in the message.
        try:
            file_path.unlink()
        except Exception as unlink_err:
            extra_msg = f"(文件删除失败:{str(unlink_err)}"
        else:
            extra_msg = "(已删除文件)"

        # Drop the cached model if it was loaded from the file just deleted.
        if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == file_path_str:
            _yolo_model = None
            _current_model_version = None
            print(f"[模型删除] 已清空全局模型缓存(路径:{file_path_str}")

        return APIResponse(
            code=200,
            message=f"模型删除成功ID{model_id} {extra_msg}",
            data=None
        )
    except MySQLError as err:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(err)}") from err
    finally:
        db.close_connection(conn, cursor)
# 7. Download a model file
@router.get("/{model_id}/download", summary="下载模型文件")
@encrypt_response()
async def download_model(model_id: int):
    """Send the model file of ``model_id`` as an attachment.

    The download is named ``model_<id>_<name>.pt``. The stored path is validated
    first; validation errors propagate as HTTPException.

    NOTE(review): this endpoint returns a FileResponse but is wrapped by
    encrypt_response(); confirm the decorator passes such responses through.

    Raises:
        HTTPException: 404 if the row does not exist; errors from path
            validation; 500 on a MySQL error.
    """
    conn = cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
        row = cursor.fetchone()
        if not row:
            raise HTTPException(status_code=404, detail=f"模型不存在ID{model_id}")
        return FileResponse(
            path=Path(get_valid_model_abs_path(row["path"])),
            filename=f"model_{model_id}_{row['name']}.pt",
            media_type="application/octet-stream",
        )
    except MySQLError as err:
        raise HTTPException(status_code=500, detail=f"数据库错误:{str(err)}") from err
    finally:
        db.close_connection(conn, cursor)