可以成功动态更换yolo模型并重启服务生效

This commit is contained in:
2025-09-12 18:28:43 +08:00
parent 4be7f7bf14
commit 206652d6bb
6 changed files with 499 additions and 123 deletions

View File

@ -1,43 +1,35 @@
import os import os
import numpy as np import numpy as np
from ultralytics import YOLO from ultralytics import YOLO
from service.model_service import get_current_yolo_model # 从模型管理模块获取模型 from service.model_service import get_current_yolo_model # 带版本校验的模型获取
# 全局模型变量
_yolo_model = None
def load_model(model_path=None): def load_model(model_path=None):
"""加载YOLO模型优先使用模型管理模块的默认模型)""" """加载YOLO模型优先使用带版本校验的默认模型)"""
global _yolo_model
if model_path is None: if model_path is None:
_yolo_model = get_current_yolo_model() # 调用带版本校验的模型获取函数(自动判断是否需要重新加载)
return _yolo_model is not None return get_current_yolo_model()
try: try:
_yolo_model = YOLO(model_path) # 加载指定路径模型(用于特殊场景)
return True return YOLO(model_path)
except Exception as e: except Exception as e:
print(f"YOLO模型加载失败指定路径{str(e)}") print(f"YOLO模型加载失败指定路径{str(e)}")
return False return None
def detect(frame, conf_threshold=0.2): def detect(frame, conf_threshold=0.7):
"""执行目标检测,返回(是否成功, 结果字符串)""" """执行目标检测(仅模型版本变化时重新加载,平时复用缓存)"""
global _yolo_model # 获取模型(内部已做版本校验,未变化则直接返回缓存)
current_model = load_model()
# 确保模型已加载 if not current_model:
if not _yolo_model: return (False, "未加载到最新YOLO模型")
if not load_model():
return (False, "模型未初始化")
if frame is None: if frame is None:
return (False, "无效输入帧") return (False, "无效输入帧")
try: try:
# 执行检测frame应为numpy数组 # 用当前模型执行检测(复用缓存,无额外加载耗时
results = _yolo_model(frame, conf=conf_threshold, verbose=False) results = current_model(frame, conf=conf_threshold, verbose=False)
has_results = len(results[0].boxes) > 0 if results else False has_results = len(results[0].boxes) > 0 if results else False
if not has_results: if not has_results:
@ -49,11 +41,17 @@ def detect(frame, conf_threshold=0.2):
cls = int(box.cls[0]) cls = int(box.cls[0])
conf = float(box.conf[0]) conf = float(box.conf[0])
bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数 bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}" # 从当前模型中获取类别名(确保与模型匹配)
class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}"
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox}") result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox}")
# 打印当前使用的模型路径和版本(用于验证)
# model_path = getattr(current_model, "model_path", "未知路径")
# from service.model_service import _current_model_version
# print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...")
return (True, "; ".join(result_parts)) return (True, "; ".join(result_parts))
except Exception as e: except Exception as e:
print(f"检测过程出错:{str(e)}") print(f"YOLO检测过程出错:{str(e)}")
return (False, f"检测错误:{str(e)}") return (False, f"检测错误:{str(e)}")

View File

@ -3,6 +3,8 @@ from mysql.connector import Error
from .config import MYSQL_CONFIG from .config import MYSQL_CONFIG
# 关键:声明类级别的连接池实例(必须有这一行!)
_connection_pool = None # 确保这一行存在,且拼写正确
class Database: class Database:
"""MySQL 连接池管理类""" """MySQL 连接池管理类"""
@ -41,6 +43,18 @@ class Database:
except Error as e: except Error as e:
raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e
@classmethod
def close_all_connections(cls):
"""清理连接池(服务重启前调用)"""
try:
# 先检查属性是否存在,再判断是否有值
if hasattr(cls, "_connection_pool") and cls._connection_pool:
cls._connection_pool = None # 重置连接池
print("[Database] 连接池已重置,旧连接将被自动清理")
else:
print("[Database] 连接池未初始化或已重置,无需操作")
except Exception as e:
print(f"[Database] 重置连接池失败: {str(e)}")
# 暴露数据库操作工具 # 暴露数据库操作工具
db = Database() db = Database()

View File

@ -8,32 +8,48 @@ from pydantic import BaseModel, Field
# 请求模型 # 请求模型
# ------------------------------ # ------------------------------
class DeviceCreateRequest(BaseModel): class DeviceCreateRequest(BaseModel):
"""设备流信息创建请求模型(与数据库表字段对齐)""" """设备创建请求模型"""
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名") hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
params: Optional[Dict] = Field(None, description="设备详细信息JSON格式") params: Optional[Dict] = Field(None, description="设备扩展参数JSON格式")
# ------------------------------ # ------------------------------
# 响应模型(后端返回数据)- 严格对齐数据库表字段 # 响应模型
# ------------------------------ # ------------------------------
class DeviceResponse(BaseModel): class DeviceResponse(BaseModel):
"""设备信息响应模型(与数据库表字段完全一致""" """设备信息响应模型(与数据库表字段对齐"""
id: int = Field(..., description="设备主键ID") id: int = Field(..., description="设备主键ID")
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址") client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
hostname: Optional[str] = Field(None, max_length=100, description="设备别名") hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
device_online_status: int = Field(..., description="设备在线状态1-在线、0-离线)") device_online_status: int = Field(..., description="在线状态1-在线、0-离线)")
device_type: Optional[str] = Field(None, description="设备类型") device_type: Optional[str] = Field(None, description="设备类型")
alarm_count: int = Field(..., description="报警次数") alarm_count: int = Field(..., description="报警次数")
params: Optional[str] = Field(None, description="设备详细信息JSON字符串") params: Optional[str] = Field(None, description="扩展参数JSON字符串")
created_at: datetime = Field(..., description="记录创建时间") created_at: datetime = Field(..., description="记录创建时间")
updated_at: datetime = Field(..., description="记录更新时间") updated_at: datetime = Field(..., description="记录更新时间")
# 支持从数据库查询结果直接转换 model_config = {"from_attributes": True} # 支持从数据库结果直接转换
model_config = {"from_attributes": True}
class DeviceListResponse(BaseModel): class DeviceListResponse(BaseModel):
"""设备流信息列表响应模型""" """设备列表响应模型"""
total: int = Field(..., description="设备总数") total: int = Field(..., description="设备总数")
devices: List[DeviceResponse] = Field(..., description="设备列表") devices: List[DeviceResponse] = Field(..., description="设备列表")
class DeviceStatusHistoryResponse(BaseModel):
"""设备上下线记录响应模型"""
id: int = Field(..., description="记录ID")
device_id: int = Field(..., description="关联设备ID")
client_ip: Optional[str] = Field(None, description="设备IP地址")
status: int = Field(..., description="状态1-在线、0-离线)")
status_time: datetime = Field(..., description="状态变更时间")
model_config = {"from_attributes": True}
class DeviceStatusHistoryListResponse(BaseModel):
"""设备上下线记录列表响应模型"""
total: int = Field(..., description="记录总数")
history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表")

View File

@ -1,10 +1,14 @@
import json import json
from datetime import datetime, date
from fastapi import APIRouter, Query, HTTPException,Request from fastapi import APIRouter, Query, HTTPException, Request, Path
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
from ds.db import db from ds.db import db
from schema.device_schema import DeviceCreateRequest, DeviceResponse, DeviceListResponse from schema.device_schema import (
DeviceCreateRequest, DeviceResponse, DeviceListResponse,
DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse
)
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
router = APIRouter( router = APIRouter(
@ -13,12 +17,53 @@ router = APIRouter(
) )
# ------------------------------
# 内部工具方法 - 记录设备状态变更历史
# ------------------------------
def record_status_change(client_ip: str, status: int) -> bool:
"""
记录设备状态变更历史(写入 device_action 表)
:param client_ip: 设备IP
:param status: 状态1-在线、0-离线)
:return: 操作是否成功
"""
if not client_ip:
raise ValueError("客户端IP不能为空")
if status not in (0, 1):
raise ValueError("状态必须是0离线或1在线")
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入状态变更记录到 device_action
insert_query = """
INSERT INTO device_action
(client_ip, action, created_at, updated_at)
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (client_ip, status))
conn.commit()
return True
except MySQLError as e:
if conn:
conn.rollback()
raise Exception(f"记录设备状态变更失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------ # ------------------------------
# 内部工具方法 - 通过客户端IP增加设备报警次数 # 内部工具方法 - 通过客户端IP增加设备报警次数
# ------------------------------ # ------------------------------
def increment_alarm_count_by_ip(client_ip: str) -> bool: def increment_alarm_count_by_ip(client_ip: str) -> bool:
""" """
通过客户端IP增加设备的报警次数(内部服务方法) 通过客户端IP增加设备的报警次数
:param client_ip: 客户端IP地址 :param client_ip: 客户端IP地址
:return: 操作是否成功 :return: 操作是否成功
@ -34,7 +79,8 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
# 检查设备是否存在 # 检查设备是否存在
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
if not cursor.fetchone(): device = cursor.fetchone()
if not device:
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 报警次数加1、并更新时间戳 # 报警次数加1、并更新时间戳
@ -61,7 +107,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
# ------------------------------ # ------------------------------
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
""" """
通过客户端IP更新设备的在线状态(内部服务方法) 通过客户端IP更新设备的在线状态
:param client_ip: 客户端IP地址 :param client_ip: 客户端IP地址
:param online_status: 在线状态1-在线、0-离线) :param online_status: 在线状态1-在线、0-离线)
@ -70,7 +116,6 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
if not client_ip: if not client_ip:
raise ValueError("客户端IP不能为空") raise ValueError("客户端IP不能为空")
# 验证状态值有效性
if online_status not in (0, 1): if online_status not in (0, 1):
raise ValueError("在线状态必须是0离线或1在线") raise ValueError("在线状态必须是0离线或1在线")
@ -80,11 +125,16 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 检查设备是否存在 # 检查设备是否存在并获取设备ID
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,))
if not cursor.fetchone(): device = cursor.fetchone()
if not device:
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
# 状态无变化则不操作
if device['device_online_status'] == online_status:
return True
# 更新在线状态和时间戳 # 更新在线状态和时间戳
update_query = """ update_query = """
UPDATE devices UPDATE devices
@ -93,8 +143,11 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
WHERE client_ip = %s WHERE client_ip = %s
""" """
cursor.execute(update_query, (online_status, client_ip)) cursor.execute(update_query, (online_status, client_ip))
conn.commit()
# 记录状态变更历史
record_status_change(client_ip, online_status)
conn.commit()
return True return True
except MySQLError as e: except MySQLError as e:
if conn: if conn:
@ -105,46 +158,43 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
# ------------------------------ # ------------------------------
# 原有接口保持不变 # 创建设备信息接口
# ------------------------------ # ------------------------------
@router.post("/add", response_model=APIResponse, summary="创建设备信息") @router.post("/add", response_model=APIResponse, summary="创建设备信息")
async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象 async def create_device(device_data: DeviceCreateRequest, request: Request):
# 原有代码保持不变
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 检查设备是否已存在
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,)) cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
existing_device = cursor.fetchone() existing_device = cursor.fetchone()
if existing_device: if existing_device:
# 更新设备状态为在线 # 更新设备为在线状态
update_online_status_by_ip(client_ip=device_data.ip, online_status=1) update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
# 返回信息
return APIResponse( return APIResponse(
code=200, code=200,
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息", message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
data=DeviceResponse(**existing_device) data=DeviceResponse(**existing_device)
) )
# 直接使用注入的request对象获取用户代理 # 通过 User-Agent 判断设备类型
user_agent = request.headers.get("User-Agent", "").lower() user_agent = request.headers.get("User-Agent", "").lower()
device_type = "unknown"
if user_agent == "default": if user_agent == "default":
device_type = device_data.params.get("os") if ( device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown"
device_data.params and isinstance(device_data.params, dict)) else "unknown"
elif "windows" in user_agent: elif "windows" in user_agent:
device_type = "windows" device_type = "windows"
elif "android" in user_agent: elif "android" in user_agent:
device_type = "android" device_type = "android"
elif "linux" in user_agent: elif "linux" in user_agent:
device_type = "linux" device_type = "linux"
else:
device_type = "unknown"
device_params_json = json.dumps(device_data.params) if device_data.params else None device_params_json = json.dumps(device_data.params) if device_data.params else None
# 插入新设备
insert_query = """ insert_query = """
INSERT INTO devices INSERT INTO devices
(client_ip, hostname, device_online_status, device_type, alarm_count, params) (client_ip, hostname, device_online_status, device_type, alarm_count, params)
@ -160,10 +210,14 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
)) ))
conn.commit() conn.commit()
# 获取新设备并返回
device_id = cursor.lastrowid device_id = cursor.lastrowid
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
new_device = cursor.fetchone() new_device = cursor.fetchone()
# 记录上线历史
record_status_change(device_data.ip, 1)
return APIResponse( return APIResponse(
code=200, code=200,
message="设备创建成功", message="设备创建成功",
@ -175,7 +229,7 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
conn.rollback() conn.rollback()
raise Exception(f"创建设备失败: {str(e)}") from e raise Exception(f"创建设备失败: {str(e)}") from e
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e
except Exception as e: except Exception as e:
if conn: if conn:
conn.rollback() conn.rollback()
@ -183,14 +237,17 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# ------------------------------
# 获取设备列表接口
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
async def get_device_list( async def get_device_list(
page: int = Query(1, ge=1, description="页码默认第1页"), page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"), page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
device_type: str = Query(None, description="按设备类型筛选"), device_type: str = Query(None, description="按设备类型筛选"),
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选") online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
): ):
# 原有代码保持不变
conn = None conn = None
cursor = None cursor = None
try: try:
@ -207,12 +264,14 @@ async def get_device_list(
where_clause.append("device_online_status = %s") where_clause.append("device_online_status = %s")
params.append(online_status) params.append(online_status)
# 统计总数
count_query = "SELECT COUNT(*) AS total FROM devices" count_query = "SELECT COUNT(*) AS total FROM devices"
if where_clause: if where_clause:
count_query += " WHERE " + " AND ".join(where_clause) count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params) cursor.execute(count_query, params)
total = cursor.fetchone()["total"] total = cursor.fetchone()["total"]
# 分页查询列表
offset = (page - 1) * page_size offset = (page - 1) * page_size
list_query = "SELECT * FROM devices" list_query = "SELECT * FROM devices"
if where_clause: if where_clause:
@ -238,23 +297,140 @@ async def get_device_list(
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
def get_unique_client_ips() -> list[str]: # ------------------------------
""" # 获取设备上下线记录接口
获取所有去重的客户端IP列表 # ------------------------------
@router.get("/{device_id}/status-history", response_model=APIResponse, summary="获取设备上下线记录")
:return: 去重后的客户端IP字符串列表如果没有数据则返回空列表 async def get_device_status_history(
""" device_id: int = Path(..., description="设备ID"),
page: int = Query(1, ge=1, description="页码默认第1页"),
page_size: int = Query(10, ge=1, le=100, description="每页条数1-100之间"),
start_date: date = Query(None, description="开始日期格式YYYY-MM-DD"),
end_date: date = Query(None, description="结束日期格式YYYY-MM-DD")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 检查设备是否存在并获取 client_ip
cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,))
device = cursor.fetchone()
if not device:
raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在")
client_ip = device['client_ip']
where_clause = ["client_ip = %s"]
params = [client_ip]
# 日期筛选
if start_date:
where_clause.append("DATE(created_at) >= %s")
params.append(start_date.strftime("%Y-%m-%d"))
if end_date:
where_clause.append("DATE(created_at) <= %s")
params.append(end_date.strftime("%Y-%m-%d"))
# 统计记录总数
count_query = "SELECT COUNT(*) AS total FROM device_action WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 分页查询记录
offset = (page - 1) * page_size
list_query = f"""
SELECT * FROM device_action
WHERE {' AND '.join(where_clause)}
ORDER BY created_at DESC
LIMIT %s OFFSET %s
"""
params.extend([page_size, offset])
cursor.execute(list_query, params)
history_list = cursor.fetchall()
# 格式化为响应模型结构
formatted_history = []
for item in history_list:
formatted_item = {
"id": item["id"],
"device_id": device_id,
"client_ip": item["client_ip"],
"status": item["action"],
"status_time": item["created_at"]
}
formatted_history.append(formatted_item)
return APIResponse(
code=200,
message="获取设备上下线记录成功",
data=DeviceStatusHistoryListResponse(
total=total,
history=[DeviceStatusHistoryResponse(**item) for item in formatted_history]
)
)
except MySQLError as e:
raise Exception(f"获取设备上下线记录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 手动更新设备在线状态接口
# ------------------------------
@router.put("/{device_id}/status", response_model=APIResponse, summary="更新设备在线状态")
async def update_device_status(
device_id: int = Path(..., description="设备ID"),
status: int = Query(..., ge=0, le=1, description="在线状态1-在线、0-离线)")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 获取设备 client_ip
cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,))
device = cursor.fetchone()
if not device:
raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在")
# 更新状态
success = update_online_status_by_ip(device['client_ip'], status)
if success:
status_text = "在线" if status == 1 else "离线"
return APIResponse(
code=200,
message=f"设备已更新为{status_text}状态",
data={"device_id": device_id, "status": status, "status_text": status_text}
)
return APIResponse(
code=500,
message="更新设备状态失败",
data=None
)
except MySQLError as e:
raise Exception(f"更新设备状态失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# ------------------------------
# 获取所有去重的客户端IP列表
# ------------------------------
def get_unique_client_ips() -> list[str]:
"""获取所有去重的客户端IP列表"""
conn = None conn = None
cursor = None cursor = None
try: try:
conn = db.get_connection() conn = db.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
# 查询去重的客户端IP
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL" query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
cursor.execute(query) cursor.execute(query)
# 提取结果并转换为字符串列表
results = cursor.fetchall() results = cursor.fetchall()
return [item['client_ip'] for item in results] return [item['client_ip'] for item in results]

View File

@ -1,10 +1,13 @@
import subprocess
import os
import sys
import shutil
import threading
from pathlib import Path
from datetime import datetime
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from mysql.connector import Error as MySQLError from mysql.connector import Error as MySQLError
import os
import shutil
from pathlib import Path
from datetime import datetime
# 复用项目依赖 # 复用项目依赖
from ds.db import db from ds.db import db
@ -15,7 +18,7 @@ from schema.model_schema import (
ModelListResponse ModelListResponse
) )
from schema.response_schema import APIResponse from schema.response_schema import APIResponse
from util.model_util import load_yolo_model # 使用修复后的模型加载工具 from util.model_util import load_yolo_model # 模型加载工具
# 路径配置 # 路径配置
CURRENT_FILE_PATH = Path(__file__).resolve() CURRENT_FILE_PATH = Path(__file__).resolve()
@ -28,14 +31,63 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
ALLOWED_MODEL_EXT = {"pt"} ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
# 全局模型变量 # 全局模型变量(带版本标识)
global _yolo_model global _yolo_model, _current_model_version
_yolo_model = None _yolo_model = None
_current_model_version = None # 模型版本标识(用于检测模型是否变化)
router = APIRouter(prefix="/models", tags=["模型管理"]) router = APIRouter(prefix="/models", tags=["模型管理"])
# 工具函数:验证模型路径 # 服务重启核心工具函数
def restart_service():
"""重启当前FastAPI服务进程"""
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
try:
# 关闭所有WebSocket连接
try:
from ws import connected_clients
if connected_clients:
print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接")
for ip, conn in list(connected_clients.items()):
try:
if conn.consumer_task and not conn.consumer_task.done():
conn.consumer_task.cancel()
conn.websocket.close(code=1001, reason="模型更新,服务重启")
connected_clients.pop(ip)
except Exception as e:
print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}")
except ImportError:
print("[服务重启] 未找到WebSocket连接管理模块跳过连接关闭")
# 关闭数据库连接
if hasattr(db, "close_all_connections"):
db.close_all_connections()
else:
print("[警告] db模块未实现close_all_connections可能存在连接泄漏")
# 启动新进程
python_exec = sys.executable
current_argv = sys.argv
print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}")
subprocess.Popen(
[python_exec] + current_argv,
close_fds=True,
start_new_session=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
# 退出当前进程
print("[服务重启] 新进程已启动,当前进程退出")
sys.exit(0)
except Exception as e:
print(f"[服务重启] 重启失败:{str(e)}")
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
# 模型路径验证工具函数
def get_valid_model_abs_path(relative_path: str) -> str: def get_valid_model_abs_path(relative_path: str) -> str:
try: try:
relative_path = relative_path.replace("/", os.sep) relative_path = relative_path.replace("/", os.sep)
@ -87,6 +139,49 @@ def get_valid_model_abs_path(relative_path: str) -> str:
) from e ) from e
# 对外提供当前模型(带版本校验)
def get_current_yolo_model():
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
global _yolo_model, _current_model_version
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT path FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
print("[get_current_yolo_model] 暂无默认模型")
return None
# 1. 计算当前默认模型的唯一版本标识
# (路径哈希 + 文件修改时间戳,确保模型变化时版本变化)
valid_abs_path = get_valid_model_abs_path(default_model["path"])
model_stat = os.stat(valid_abs_path)
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
# 2. 版本未变化则复用已有模型(核心优化点)
if _yolo_model and _current_model_version == model_version:
# print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...")
return _yolo_model
# 3. 版本变化时重新加载模型
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
_current_model_version = model_version # 更新版本标识
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...")
else:
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
return _yolo_model
except Exception as e:
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
return None
finally:
db.close_connection(conn, cursor)
# 1. 上传模型 # 1. 上传模型
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式") @router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
async def upload_model( async def upload_model(
@ -142,12 +237,16 @@ async def upload_model(
if not new_model: if not new_model:
raise HTTPException(status_code=500, detail="上传成功但无法获取记录") raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
# 加载默认模型 # 加载默认模型并更新版本
global _yolo_model global _yolo_model, _current_model_version
if is_default: if is_default:
valid_abs_path = get_valid_model_abs_path(db_relative_path) valid_abs_path = get_valid_model_abs_path(db_relative_path)
_yolo_model = load_yolo_model(valid_abs_path) _yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model: if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path}" detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path}"
@ -246,11 +345,15 @@ async def get_default_model():
raise HTTPException(status_code=404, detail="暂无默认模型") raise HTTPException(status_code=404, detail="暂无默认模型")
valid_abs_path = get_valid_model_abs_path(default_model["path"]) valid_abs_path = get_valid_model_abs_path(default_model["path"])
global _yolo_model global _yolo_model, _current_model_version
if not _yolo_model: if not _yolo_model:
_yolo_model = load_yolo_model(valid_abs_path) _yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model: if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path}" detail=f"默认模型存在,但加载失败(路径:{valid_abs_path}"
@ -358,11 +461,16 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone() updated_model = cursor.fetchone()
global _yolo_model # 更新模型后重置版本标识
global _yolo_model, _current_model_version
if need_load_default: if need_load_default:
valid_abs_path = get_valid_model_abs_path(updated_model["path"]) valid_abs_path = get_valid_model_abs_path(updated_model["path"])
_yolo_model = load_yolo_model(valid_abs_path) _yolo_model = load_yolo_model(valid_abs_path)
if not _yolo_model: if _yolo_model:
setattr(_yolo_model, "model_path", valid_abs_path)
model_stat = os.stat(valid_abs_path)
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
else:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path}" detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path}"
@ -382,6 +490,96 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 5.1 更换默认模型(自动重启服务)
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
async def set_default_model(model_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
conn.autocommit = False # 开启事务
# 1. 校验目标模型是否存在
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
target_model = cursor.fetchone()
if not target_model:
raise HTTPException(status_code=404, detail=f"目标模型不存在ID{model_id}")
# 2. 检查是否已为默认模型
if target_model["is_default"]:
return APIResponse(
code=200,
message=f"模型ID{model_id} 已是默认模型,无需更换和重启",
data=ModelResponse(**target_model)
)
# 3. 校验目标模型文件合法性
try:
valid_abs_path = get_valid_model_abs_path(target_model["path"])
except HTTPException as e:
raise HTTPException(
status_code=400,
detail=f"目标模型文件非法,无法设为默认:{e.detail}"
) from e
# 4. 数据库事务:更新默认模型状态
try:
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
cursor.execute(
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
(model_id,)
)
conn.commit()
except MySQLError as e:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
) from e
# 5. 验证新模型可加载性
test_model = load_yolo_model(valid_abs_path)
if not test_model:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path}"
)
# 6. 重新查询更新后的模型信息
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
updated_model = cursor.fetchone()
# 7. 重置版本标识(关键:确保下次检测加载新模型)
global _current_model_version
_current_model_version = None
print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型")
# 8. 延迟重启服务
print(f"[更换默认模型] 成功将在1秒后重启服务以应用新模型ID{model_id}")
threading.Timer(
interval=1.0,
function=restart_service
).start()
# 9. 返回成功响应
return APIResponse(
code=200,
message=f"已成功更换默认模型ID{model_id}服务将在1秒后自动重启以应用新模型",
data=ModelResponse(**updated_model)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
if conn:
conn.autocommit = True
db.close_connection(conn, cursor)
# 6. 删除模型 # 6. 删除模型
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
async def delete_model(model_id: int): async def delete_model(model_id: int):
@ -420,10 +618,12 @@ async def delete_model(model_id: int):
except Exception as e: except Exception as e:
extra_msg = f"(文件删除失败:{str(e)}" extra_msg = f"(文件删除失败:{str(e)}"
global _yolo_model # 如果删除的是当前加载的模型,重置缓存
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str: global _yolo_model, _current_model_version
if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str:
_yolo_model = None _yolo_model = None
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str}") _current_model_version = None
print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str}")
return APIResponse( return APIResponse(
code=200, code=200,
@ -466,32 +666,3 @@ async def download_model(model_id: int):
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally: finally:
db.close_connection(conn, cursor) db.close_connection(conn, cursor)
# 对外提供当前模型
def get_current_yolo_model():
"""供检测模块获取当前加载的模型"""
global _yolo_model
if not _yolo_model:
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT path FROM model WHERE is_default = 1")
default_model = cursor.fetchone()
if not default_model:
print("[get_current_yolo_model] 暂无默认模型")
return None
valid_abs_path = get_valid_model_abs_path(default_model["path"])
_yolo_model = load_yolo_model(valid_abs_path)
if _yolo_model:
print(f"[get_current_yolo_model] 自动加载默认模型成功")
else:
print(f"[get_current_yolo_model] 自动加载默认模型失败")
except Exception as e:
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
finally:
db.close_connection(conn, cursor)
return _yolo_model

View File

@ -12,7 +12,7 @@ def save_face_to_up_images(
) -> Dict[str, str]: ) -> Dict[str, str]:
""" """
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
修复路径计算错误确保所有路径在up_images根目录下 修复路径计算错误确保所有路径在up_images根目录下,且统一使用正斜杠
参数: 参数:
client_ip: 客户端IP原始格式如192.168.1.101 client_ip: 客户端IP原始格式如192.168.1.101
@ -38,7 +38,7 @@ def save_face_to_up_images(
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符 safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
root_dir = Path("up_images").resolve() # 转为绝对路径(关键修复!) root_dir = Path("up_images").resolve() # 转为绝对路径
if not root_dir.exists(): if not root_dir.exists():
root_dir.mkdir(parents=True, exist_ok=True) root_dir.mkdir(parents=True, exist_ok=True)
print(f"[FileUtil] 已创建up_images根目录{root_dir}") print(f"[FileUtil] 已创建up_images根目录{root_dir}")
@ -53,15 +53,16 @@ def save_face_to_up_images(
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
# 6. 计算路径(关键修复:确保所有路径都是绝对路径且在root_dir下 # 6. 计算路径确保所有路径都是绝对路径且在root_dir下
local_abs_path = face_name_dir / image_filename # 绝对路径 local_abs_path = face_name_dir / image_filename # 绝对路径
# 验证路径是否在root_dir下防止路径穿越攻击 # 验证路径是否在root_dir下防止路径穿越攻击
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
raise Exception(f"图片路径不在up_images根目录下安全校验失败{local_abs_path}") raise Exception(f"图片路径不在up_images根目录下安全校验失败{local_abs_path}")
# 数据库存储路径从root_dir开始的相对路径(如 up_images/192_168_110_31/小王/xxx.jpg # 数据库存储路径从root_dir开始的相对路径,强制替换为正斜杠
db_path = str(root_dir.name / local_abs_path.relative_to(root_dir)) relative_path = local_abs_path.relative_to(root_dir)
db_path = str(relative_path).replace("\\", "/") # 关键修复:统一使用正斜杠
# 7. 写入图片文件 # 7. 写入图片文件
with open(local_abs_path, "wb") as f: with open(local_abs_path, "wb") as f:
@ -72,7 +73,7 @@ def save_face_to_up_images(
return { return {
"success": True, "success": True,
"db_path": db_path, # 存数据库的相对路径(up_images开头 "db_path": db_path, # 存数据库的相对路径(使用正斜杠
"local_abs_path": str(local_abs_path), # 本地绝对路径 "local_abs_path": str(local_abs_path), # 本地绝对路径
"msg": "图片保存成功" "msg": "图片保存成功"
} }