可以成功动态更换yolo模型并重启服务生效
This commit is contained in:
48
core/yolo.py
48
core/yolo.py
@ -1,43 +1,35 @@
|
|||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
from service.model_service import get_current_yolo_model # 从模型管理模块获取模型
|
from service.model_service import get_current_yolo_model # 带版本校验的模型获取
|
||||||
|
|
||||||
# 全局模型变量
|
|
||||||
_yolo_model = None
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path=None):
|
def load_model(model_path=None):
|
||||||
"""加载YOLO模型(优先使用模型管理模块的默认模型)"""
|
"""加载YOLO模型(优先使用带版本校验的默认模型)"""
|
||||||
global _yolo_model
|
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
_yolo_model = get_current_yolo_model()
|
# 调用带版本校验的模型获取函数(自动判断是否需要重新加载)
|
||||||
return _yolo_model is not None
|
return get_current_yolo_model()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_yolo_model = YOLO(model_path)
|
# 加载指定路径模型(用于特殊场景)
|
||||||
return True
|
return YOLO(model_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"YOLO模型加载失败(指定路径):{str(e)}")
|
print(f"YOLO模型加载失败(指定路径):{str(e)}")
|
||||||
return False
|
return None
|
||||||
|
|
||||||
|
|
||||||
def detect(frame, conf_threshold=0.2):
|
def detect(frame, conf_threshold=0.7):
|
||||||
"""执行目标检测,返回(是否成功, 结果字符串)"""
|
"""执行目标检测(仅模型版本变化时重新加载,平时复用缓存)"""
|
||||||
global _yolo_model
|
# 获取模型(内部已做版本校验,未变化则直接返回缓存)
|
||||||
|
current_model = load_model()
|
||||||
# 确保模型已加载
|
if not current_model:
|
||||||
if not _yolo_model:
|
return (False, "未加载到最新YOLO模型")
|
||||||
if not load_model():
|
|
||||||
return (False, "模型未初始化")
|
|
||||||
|
|
||||||
if frame is None:
|
if frame is None:
|
||||||
return (False, "无效输入帧")
|
return (False, "无效输入帧")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 执行检测(frame应为numpy数组)
|
# 用当前模型执行检测(复用缓存,无额外加载耗时)
|
||||||
results = _yolo_model(frame, conf=conf_threshold, verbose=False)
|
results = current_model(frame, conf=conf_threshold, verbose=False)
|
||||||
has_results = len(results[0].boxes) > 0 if results else False
|
has_results = len(results[0].boxes) > 0 if results else False
|
||||||
|
|
||||||
if not has_results:
|
if not has_results:
|
||||||
@ -49,11 +41,17 @@ def detect(frame, conf_threshold=0.2):
|
|||||||
cls = int(box.cls[0])
|
cls = int(box.cls[0])
|
||||||
conf = float(box.conf[0])
|
conf = float(box.conf[0])
|
||||||
bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数
|
bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数
|
||||||
class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}"
|
# 从当前模型中获取类别名(确保与模型匹配)
|
||||||
|
class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}"
|
||||||
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})")
|
result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})")
|
||||||
|
|
||||||
|
# 打印当前使用的模型路径和版本(用于验证)
|
||||||
|
# model_path = getattr(current_model, "model_path", "未知路径")
|
||||||
|
# from service.model_service import _current_model_version
|
||||||
|
# print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...)")
|
||||||
|
|
||||||
return (True, "; ".join(result_parts))
|
return (True, "; ".join(result_parts))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"检测过程出错:{str(e)}")
|
print(f"YOLO检测过程出错:{str(e)}")
|
||||||
return (False, f"检测错误:{str(e)}")
|
return (False, f"检测错误:{str(e)}")
|
||||||
|
14
ds/db.py
14
ds/db.py
@ -3,6 +3,8 @@ from mysql.connector import Error
|
|||||||
|
|
||||||
from .config import MYSQL_CONFIG
|
from .config import MYSQL_CONFIG
|
||||||
|
|
||||||
|
# 关键:声明类级别的连接池实例(必须有这一行!)
|
||||||
|
_connection_pool = None # 确保这一行存在,且拼写正确
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
"""MySQL 连接池管理类"""
|
"""MySQL 连接池管理类"""
|
||||||
@ -41,6 +43,18 @@ class Database:
|
|||||||
except Error as e:
|
except Error as e:
|
||||||
raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e
|
raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def close_all_connections(cls):
|
||||||
|
"""清理连接池(服务重启前调用)"""
|
||||||
|
try:
|
||||||
|
# 先检查属性是否存在,再判断是否有值
|
||||||
|
if hasattr(cls, "_connection_pool") and cls._connection_pool:
|
||||||
|
cls._connection_pool = None # 重置连接池
|
||||||
|
print("[Database] 连接池已重置,旧连接将被自动清理")
|
||||||
|
else:
|
||||||
|
print("[Database] 连接池未初始化或已重置,无需操作")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Database] 重置连接池失败: {str(e)}")
|
||||||
|
|
||||||
# 暴露数据库操作工具
|
# 暴露数据库操作工具
|
||||||
db = Database()
|
db = Database()
|
||||||
|
@ -8,32 +8,48 @@ from pydantic import BaseModel, Field
|
|||||||
# 请求模型
|
# 请求模型
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
class DeviceCreateRequest(BaseModel):
|
class DeviceCreateRequest(BaseModel):
|
||||||
"""设备流信息创建请求模型(与数据库表字段对齐)"""
|
"""设备创建请求模型"""
|
||||||
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
|
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
|
||||||
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||||
params: Optional[Dict] = Field(None, description="设备详细信息(JSON格式)")
|
params: Optional[Dict] = Field(None, description="设备扩展参数(JSON格式)")
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 响应模型(后端返回数据)- 严格对齐数据库表字段
|
# 响应模型
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
class DeviceResponse(BaseModel):
|
class DeviceResponse(BaseModel):
|
||||||
"""设备流信息响应模型(与数据库表字段完全一致)"""
|
"""单设备信息响应模型(与数据库表字段对齐)"""
|
||||||
id: int = Field(..., description="设备主键ID")
|
id: int = Field(..., description="设备主键ID")
|
||||||
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
|
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
|
||||||
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||||
device_online_status: int = Field(..., description="设备在线状态(1-在线、0-离线)")
|
device_online_status: int = Field(..., description="在线状态(1-在线、0-离线)")
|
||||||
device_type: Optional[str] = Field(None, description="设备类型")
|
device_type: Optional[str] = Field(None, description="设备类型")
|
||||||
alarm_count: int = Field(..., description="报警次数")
|
alarm_count: int = Field(..., description="报警次数")
|
||||||
params: Optional[str] = Field(None, description="设备详细信息(JSON字符串)")
|
params: Optional[str] = Field(None, description="扩展参数(JSON字符串)")
|
||||||
created_at: datetime = Field(..., description="记录创建时间")
|
created_at: datetime = Field(..., description="记录创建时间")
|
||||||
updated_at: datetime = Field(..., description="记录更新时间")
|
updated_at: datetime = Field(..., description="记录更新时间")
|
||||||
|
|
||||||
# 支持从数据库查询结果直接转换
|
model_config = {"from_attributes": True} # 支持从数据库结果直接转换
|
||||||
model_config = {"from_attributes": True}
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceListResponse(BaseModel):
|
class DeviceListResponse(BaseModel):
|
||||||
"""设备流信息列表响应模型"""
|
"""设备列表响应模型"""
|
||||||
total: int = Field(..., description="设备总数")
|
total: int = Field(..., description="设备总数")
|
||||||
devices: List[DeviceResponse] = Field(..., description="设备列表")
|
devices: List[DeviceResponse] = Field(..., description="设备列表")
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceStatusHistoryResponse(BaseModel):
|
||||||
|
"""设备上下线记录响应模型"""
|
||||||
|
id: int = Field(..., description="记录ID")
|
||||||
|
device_id: int = Field(..., description="关联设备ID")
|
||||||
|
client_ip: Optional[str] = Field(None, description="设备IP地址")
|
||||||
|
status: int = Field(..., description="状态(1-在线、0-离线)")
|
||||||
|
status_time: datetime = Field(..., description="状态变更时间")
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceStatusHistoryListResponse(BaseModel):
|
||||||
|
"""设备上下线记录列表响应模型"""
|
||||||
|
total: int = Field(..., description="记录总数")
|
||||||
|
history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表")
|
@ -1,10 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
|
from datetime import datetime, date
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, HTTPException,Request
|
from fastapi import APIRouter, Query, HTTPException, Request, Path
|
||||||
from mysql.connector import Error as MySQLError
|
from mysql.connector import Error as MySQLError
|
||||||
|
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
from schema.device_schema import DeviceCreateRequest, DeviceResponse, DeviceListResponse
|
from schema.device_schema import (
|
||||||
|
DeviceCreateRequest, DeviceResponse, DeviceListResponse,
|
||||||
|
DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse
|
||||||
|
)
|
||||||
from schema.response_schema import APIResponse
|
from schema.response_schema import APIResponse
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@ -13,12 +17,53 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 内部工具方法 - 记录设备状态变更历史
|
||||||
|
# ------------------------------
|
||||||
|
def record_status_change(client_ip: str, status: int) -> bool:
|
||||||
|
"""
|
||||||
|
记录设备状态变更历史(写入 device_action 表)
|
||||||
|
|
||||||
|
:param client_ip: 设备IP
|
||||||
|
:param status: 状态(1-在线、0-离线)
|
||||||
|
:return: 操作是否成功
|
||||||
|
"""
|
||||||
|
if not client_ip:
|
||||||
|
raise ValueError("客户端IP不能为空")
|
||||||
|
|
||||||
|
if status not in (0, 1):
|
||||||
|
raise ValueError("状态必须是0(离线)或1(在线)")
|
||||||
|
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 插入状态变更记录到 device_action
|
||||||
|
insert_query = """
|
||||||
|
INSERT INTO device_action
|
||||||
|
(client_ip, action, created_at, updated_at)
|
||||||
|
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
|
"""
|
||||||
|
cursor.execute(insert_query, (client_ip, status))
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
return True
|
||||||
|
except MySQLError as e:
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
raise Exception(f"记录设备状态变更失败: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 内部工具方法 - 通过客户端IP增加设备报警次数
|
# 内部工具方法 - 通过客户端IP增加设备报警次数
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
||||||
"""
|
"""
|
||||||
通过客户端IP增加设备的报警次数(内部服务方法)
|
通过客户端IP增加设备的报警次数
|
||||||
|
|
||||||
:param client_ip: 客户端IP地址
|
:param client_ip: 客户端IP地址
|
||||||
:return: 操作是否成功
|
:return: 操作是否成功
|
||||||
@ -34,7 +79,8 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
|||||||
|
|
||||||
# 检查设备是否存在
|
# 检查设备是否存在
|
||||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||||
if not cursor.fetchone():
|
device = cursor.fetchone()
|
||||||
|
if not device:
|
||||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||||
|
|
||||||
# 报警次数加1、并更新时间戳
|
# 报警次数加1、并更新时间戳
|
||||||
@ -61,7 +107,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
|||||||
# ------------------------------
|
# ------------------------------
|
||||||
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
||||||
"""
|
"""
|
||||||
通过客户端IP更新设备的在线状态(内部服务方法)
|
通过客户端IP更新设备的在线状态
|
||||||
|
|
||||||
:param client_ip: 客户端IP地址
|
:param client_ip: 客户端IP地址
|
||||||
:param online_status: 在线状态(1-在线、0-离线)
|
:param online_status: 在线状态(1-在线、0-离线)
|
||||||
@ -70,7 +116,6 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
|||||||
if not client_ip:
|
if not client_ip:
|
||||||
raise ValueError("客户端IP不能为空")
|
raise ValueError("客户端IP不能为空")
|
||||||
|
|
||||||
# 验证状态值有效性
|
|
||||||
if online_status not in (0, 1):
|
if online_status not in (0, 1):
|
||||||
raise ValueError("在线状态必须是0(离线)或1(在线)")
|
raise ValueError("在线状态必须是0(离线)或1(在线)")
|
||||||
|
|
||||||
@ -80,11 +125,16 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
|||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 检查设备是否存在
|
# 检查设备是否存在并获取设备ID
|
||||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,))
|
||||||
if not cursor.fetchone():
|
device = cursor.fetchone()
|
||||||
|
if not device:
|
||||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||||
|
|
||||||
|
# 状态无变化则不操作
|
||||||
|
if device['device_online_status'] == online_status:
|
||||||
|
return True
|
||||||
|
|
||||||
# 更新在线状态和时间戳
|
# 更新在线状态和时间戳
|
||||||
update_query = """
|
update_query = """
|
||||||
UPDATE devices
|
UPDATE devices
|
||||||
@ -93,8 +143,11 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
|||||||
WHERE client_ip = %s
|
WHERE client_ip = %s
|
||||||
"""
|
"""
|
||||||
cursor.execute(update_query, (online_status, client_ip))
|
cursor.execute(update_query, (online_status, client_ip))
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
# 记录状态变更历史
|
||||||
|
record_status_change(client_ip, online_status)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
return True
|
return True
|
||||||
except MySQLError as e:
|
except MySQLError as e:
|
||||||
if conn:
|
if conn:
|
||||||
@ -105,46 +158,43 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# 原有接口保持不变
|
# 创建设备信息接口
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||||
async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象
|
async def create_device(device_data: DeviceCreateRequest, request: Request):
|
||||||
# 原有代码保持不变
|
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 检查设备是否已存在
|
||||||
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
|
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
|
||||||
existing_device = cursor.fetchone()
|
existing_device = cursor.fetchone()
|
||||||
if existing_device:
|
if existing_device:
|
||||||
# 更新设备状态为在线
|
# 更新设备为在线状态
|
||||||
update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
|
update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
|
||||||
# 返回信息
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
|
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
|
||||||
data=DeviceResponse(** existing_device)
|
data=DeviceResponse(**existing_device)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 直接使用注入的request对象获取用户代理
|
# 通过 User-Agent 判断设备类型
|
||||||
user_agent = request.headers.get("User-Agent", "").lower()
|
user_agent = request.headers.get("User-Agent", "").lower()
|
||||||
|
device_type = "unknown"
|
||||||
if user_agent == "default":
|
if user_agent == "default":
|
||||||
device_type = device_data.params.get("os") if (
|
device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown"
|
||||||
device_data.params and isinstance(device_data.params, dict)) else "unknown"
|
|
||||||
elif "windows" in user_agent:
|
elif "windows" in user_agent:
|
||||||
device_type = "windows"
|
device_type = "windows"
|
||||||
elif "android" in user_agent:
|
elif "android" in user_agent:
|
||||||
device_type = "android"
|
device_type = "android"
|
||||||
elif "linux" in user_agent:
|
elif "linux" in user_agent:
|
||||||
device_type = "linux"
|
device_type = "linux"
|
||||||
else:
|
|
||||||
device_type = "unknown"
|
|
||||||
|
|
||||||
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
||||||
|
|
||||||
|
# 插入新设备
|
||||||
insert_query = """
|
insert_query = """
|
||||||
INSERT INTO devices
|
INSERT INTO devices
|
||||||
(client_ip, hostname, device_online_status, device_type, alarm_count, params)
|
(client_ip, hostname, device_online_status, device_type, alarm_count, params)
|
||||||
@ -160,10 +210,14 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
|
|||||||
))
|
))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
# 获取新设备并返回
|
||||||
device_id = cursor.lastrowid
|
device_id = cursor.lastrowid
|
||||||
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
|
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
|
||||||
new_device = cursor.fetchone()
|
new_device = cursor.fetchone()
|
||||||
|
|
||||||
|
# 记录上线历史
|
||||||
|
record_status_change(device_data.ip, 1)
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="设备创建成功",
|
message="设备创建成功",
|
||||||
@ -175,7 +229,7 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
|
|||||||
conn.rollback()
|
conn.rollback()
|
||||||
raise Exception(f"创建设备失败: {str(e)}") from e
|
raise Exception(f"创建设备失败: {str(e)}") from e
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e
|
raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
@ -183,14 +237,17 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): #
|
|||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 获取设备列表接口
|
||||||
|
# ------------------------------
|
||||||
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
||||||
async def get_device_list(
|
async def get_device_list(
|
||||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||||
device_type: str = Query(None, description="按设备类型筛选"),
|
device_type: str = Query(None, description="按设备类型筛选"),
|
||||||
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
|
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
|
||||||
):
|
):
|
||||||
# 原有代码保持不变
|
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
@ -207,12 +264,14 @@ async def get_device_list(
|
|||||||
where_clause.append("device_online_status = %s")
|
where_clause.append("device_online_status = %s")
|
||||||
params.append(online_status)
|
params.append(online_status)
|
||||||
|
|
||||||
|
# 统计总数
|
||||||
count_query = "SELECT COUNT(*) AS total FROM devices"
|
count_query = "SELECT COUNT(*) AS total FROM devices"
|
||||||
if where_clause:
|
if where_clause:
|
||||||
count_query += " WHERE " + " AND ".join(where_clause)
|
count_query += " WHERE " + " AND ".join(where_clause)
|
||||||
cursor.execute(count_query, params)
|
cursor.execute(count_query, params)
|
||||||
total = cursor.fetchone()["total"]
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
|
# 分页查询列表
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
list_query = "SELECT * FROM devices"
|
list_query = "SELECT * FROM devices"
|
||||||
if where_clause:
|
if where_clause:
|
||||||
@ -238,23 +297,140 @@ async def get_device_list(
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
def get_unique_client_ips() -> list[str]:
|
# ------------------------------
|
||||||
"""
|
# 获取设备上下线记录接口
|
||||||
获取所有去重的客户端IP列表
|
# ------------------------------
|
||||||
|
@router.get("/{device_id}/status-history", response_model=APIResponse, summary="获取设备上下线记录")
|
||||||
:return: 去重后的客户端IP字符串列表,如果没有数据则返回空列表
|
async def get_device_status_history(
|
||||||
"""
|
device_id: int = Path(..., description="设备ID"),
|
||||||
|
page: int = Query(1, ge=1, description="页码,默认第1页"),
|
||||||
|
page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"),
|
||||||
|
start_date: date = Query(None, description="开始日期,格式YYYY-MM-DD"),
|
||||||
|
end_date: date = Query(None, description="结束日期,格式YYYY-MM-DD")
|
||||||
|
):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 检查设备是否存在并获取 client_ip
|
||||||
|
cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,))
|
||||||
|
device = cursor.fetchone()
|
||||||
|
if not device:
|
||||||
|
raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在")
|
||||||
|
client_ip = device['client_ip']
|
||||||
|
|
||||||
|
where_clause = ["client_ip = %s"]
|
||||||
|
params = [client_ip]
|
||||||
|
|
||||||
|
# 日期筛选
|
||||||
|
if start_date:
|
||||||
|
where_clause.append("DATE(created_at) >= %s")
|
||||||
|
params.append(start_date.strftime("%Y-%m-%d"))
|
||||||
|
if end_date:
|
||||||
|
where_clause.append("DATE(created_at) <= %s")
|
||||||
|
params.append(end_date.strftime("%Y-%m-%d"))
|
||||||
|
|
||||||
|
# 统计记录总数
|
||||||
|
count_query = "SELECT COUNT(*) AS total FROM device_action WHERE " + " AND ".join(where_clause)
|
||||||
|
cursor.execute(count_query, params)
|
||||||
|
total = cursor.fetchone()["total"]
|
||||||
|
|
||||||
|
# 分页查询记录
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
list_query = f"""
|
||||||
|
SELECT * FROM device_action
|
||||||
|
WHERE {' AND '.join(where_clause)}
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT %s OFFSET %s
|
||||||
|
"""
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
cursor.execute(list_query, params)
|
||||||
|
history_list = cursor.fetchall()
|
||||||
|
|
||||||
|
# 格式化为响应模型结构
|
||||||
|
formatted_history = []
|
||||||
|
for item in history_list:
|
||||||
|
formatted_item = {
|
||||||
|
"id": item["id"],
|
||||||
|
"device_id": device_id,
|
||||||
|
"client_ip": item["client_ip"],
|
||||||
|
"status": item["action"],
|
||||||
|
"status_time": item["created_at"]
|
||||||
|
}
|
||||||
|
formatted_history.append(formatted_item)
|
||||||
|
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message="获取设备上下线记录成功",
|
||||||
|
data=DeviceStatusHistoryListResponse(
|
||||||
|
total=total,
|
||||||
|
history=[DeviceStatusHistoryResponse(**item) for item in formatted_history]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
raise Exception(f"获取设备上下线记录失败: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 手动更新设备在线状态接口
|
||||||
|
# ------------------------------
|
||||||
|
@router.put("/{device_id}/status", response_model=APIResponse, summary="更新设备在线状态")
|
||||||
|
async def update_device_status(
|
||||||
|
device_id: int = Path(..., description="设备ID"),
|
||||||
|
status: int = Query(..., ge=0, le=1, description="在线状态(1-在线、0-离线)")
|
||||||
|
):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
|
# 获取设备 client_ip
|
||||||
|
cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,))
|
||||||
|
device = cursor.fetchone()
|
||||||
|
if not device:
|
||||||
|
raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在")
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
success = update_online_status_by_ip(device['client_ip'], status)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
status_text = "在线" if status == 1 else "离线"
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message=f"设备已更新为{status_text}状态",
|
||||||
|
data={"device_id": device_id, "status": status, "status_text": status_text}
|
||||||
|
)
|
||||||
|
return APIResponse(
|
||||||
|
code=500,
|
||||||
|
message="更新设备状态失败",
|
||||||
|
data=None
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
raise Exception(f"更新设备状态失败: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# 获取所有去重的客户端IP列表
|
||||||
|
# ------------------------------
|
||||||
|
def get_unique_client_ips() -> list[str]:
|
||||||
|
"""获取所有去重的客户端IP列表"""
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = db.get_connection()
|
conn = db.get_connection()
|
||||||
cursor = conn.cursor(dictionary=True)
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
|
||||||
# 查询去重的客户端IP
|
|
||||||
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
|
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
|
|
||||||
# 提取结果并转换为字符串列表
|
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
return [item['client_ip'] for item in results]
|
return [item['client_ip'] for item in results]
|
||||||
|
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import shutil
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from mysql.connector import Error as MySQLError
|
from mysql.connector import Error as MySQLError
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# 复用项目依赖
|
# 复用项目依赖
|
||||||
from ds.db import db
|
from ds.db import db
|
||||||
@ -15,7 +18,7 @@ from schema.model_schema import (
|
|||||||
ModelListResponse
|
ModelListResponse
|
||||||
)
|
)
|
||||||
from schema.response_schema import APIResponse
|
from schema.response_schema import APIResponse
|
||||||
from util.model_util import load_yolo_model # 使用修复后的模型加载工具
|
from util.model_util import load_yolo_model # 模型加载工具
|
||||||
|
|
||||||
# 路径配置
|
# 路径配置
|
||||||
CURRENT_FILE_PATH = Path(__file__).resolve()
|
CURRENT_FILE_PATH = Path(__file__).resolve()
|
||||||
@ -28,14 +31,63 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep
|
|||||||
ALLOWED_MODEL_EXT = {"pt"}
|
ALLOWED_MODEL_EXT = {"pt"}
|
||||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||||
|
|
||||||
# 全局模型变量
|
# 全局模型变量(带版本标识)
|
||||||
global _yolo_model
|
global _yolo_model, _current_model_version
|
||||||
_yolo_model = None
|
_yolo_model = None
|
||||||
|
_current_model_version = None # 模型版本标识(用于检测模型是否变化)
|
||||||
|
|
||||||
router = APIRouter(prefix="/models", tags=["模型管理"])
|
router = APIRouter(prefix="/models", tags=["模型管理"])
|
||||||
|
|
||||||
|
|
||||||
# 工具函数:验证模型路径
|
# 服务重启核心工具函数
|
||||||
|
def restart_service():
|
||||||
|
"""重启当前FastAPI服务进程"""
|
||||||
|
print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...")
|
||||||
|
try:
|
||||||
|
# 关闭所有WebSocket连接
|
||||||
|
try:
|
||||||
|
from ws import connected_clients
|
||||||
|
if connected_clients:
|
||||||
|
print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接")
|
||||||
|
for ip, conn in list(connected_clients.items()):
|
||||||
|
try:
|
||||||
|
if conn.consumer_task and not conn.consumer_task.done():
|
||||||
|
conn.consumer_task.cancel()
|
||||||
|
conn.websocket.close(code=1001, reason="模型更新,服务重启")
|
||||||
|
connected_clients.pop(ip)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}")
|
||||||
|
except ImportError:
|
||||||
|
print("[服务重启] 未找到WebSocket连接管理模块,跳过连接关闭")
|
||||||
|
|
||||||
|
# 关闭数据库连接
|
||||||
|
if hasattr(db, "close_all_connections"):
|
||||||
|
db.close_all_connections()
|
||||||
|
else:
|
||||||
|
print("[警告] db模块未实现close_all_connections,可能存在连接泄漏")
|
||||||
|
|
||||||
|
# 启动新进程
|
||||||
|
python_exec = sys.executable
|
||||||
|
current_argv = sys.argv
|
||||||
|
print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}")
|
||||||
|
subprocess.Popen(
|
||||||
|
[python_exec] + current_argv,
|
||||||
|
close_fds=True,
|
||||||
|
start_new_session=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE
|
||||||
|
)
|
||||||
|
|
||||||
|
# 退出当前进程
|
||||||
|
print("[服务重启] 新进程已启动,当前进程退出")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[服务重启] 重启失败:{str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# 模型路径验证工具函数
|
||||||
def get_valid_model_abs_path(relative_path: str) -> str:
|
def get_valid_model_abs_path(relative_path: str) -> str:
|
||||||
try:
|
try:
|
||||||
relative_path = relative_path.replace("/", os.sep)
|
relative_path = relative_path.replace("/", os.sep)
|
||||||
@ -87,6 +139,49 @@ def get_valid_model_abs_path(relative_path: str) -> str:
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
# 对外提供当前模型(带版本校验)
|
||||||
|
def get_current_yolo_model():
|
||||||
|
"""供检测模块获取当前最新默认模型(仅版本变化时重新加载)"""
|
||||||
|
global _yolo_model, _current_model_version
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute("SELECT path FROM model WHERE is_default = 1")
|
||||||
|
default_model = cursor.fetchone()
|
||||||
|
if not default_model:
|
||||||
|
print("[get_current_yolo_model] 暂无默认模型")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 1. 计算当前默认模型的唯一版本标识
|
||||||
|
# (路径哈希 + 文件修改时间戳,确保模型变化时版本变化)
|
||||||
|
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||||
|
model_stat = os.stat(valid_abs_path)
|
||||||
|
model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||||
|
|
||||||
|
# 2. 版本未变化则复用已有模型(核心优化点)
|
||||||
|
if _yolo_model and _current_model_version == model_version:
|
||||||
|
# print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...)")
|
||||||
|
return _yolo_model
|
||||||
|
|
||||||
|
# 3. 版本变化时重新加载模型
|
||||||
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
|
if _yolo_model:
|
||||||
|
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||||
|
_current_model_version = model_version # 更新版本标识
|
||||||
|
print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)")
|
||||||
|
else:
|
||||||
|
print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}")
|
||||||
|
return _yolo_model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 1. 上传模型
|
# 1. 上传模型
|
||||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||||
async def upload_model(
|
async def upload_model(
|
||||||
@ -142,12 +237,16 @@ async def upload_model(
|
|||||||
if not new_model:
|
if not new_model:
|
||||||
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||||||
|
|
||||||
# 加载默认模型
|
# 加载默认模型并更新版本
|
||||||
global _yolo_model
|
global _yolo_model, _current_model_version
|
||||||
if is_default:
|
if is_default:
|
||||||
valid_abs_path = get_valid_model_abs_path(db_relative_path)
|
valid_abs_path = get_valid_model_abs_path(db_relative_path)
|
||||||
_yolo_model = load_yolo_model(valid_abs_path)
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
if not _yolo_model:
|
if _yolo_model:
|
||||||
|
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||||
|
model_stat = os.stat(valid_abs_path)
|
||||||
|
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||||
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})"
|
detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})"
|
||||||
@ -246,11 +345,15 @@ async def get_default_model():
|
|||||||
raise HTTPException(status_code=404, detail="暂无默认模型")
|
raise HTTPException(status_code=404, detail="暂无默认模型")
|
||||||
|
|
||||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
||||||
global _yolo_model
|
global _yolo_model, _current_model_version
|
||||||
|
|
||||||
if not _yolo_model:
|
if not _yolo_model:
|
||||||
_yolo_model = load_yolo_model(valid_abs_path)
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
if not _yolo_model:
|
if _yolo_model:
|
||||||
|
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||||
|
model_stat = os.stat(valid_abs_path)
|
||||||
|
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||||
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})"
|
detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})"
|
||||||
@ -358,11 +461,16 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
|||||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
updated_model = cursor.fetchone()
|
updated_model = cursor.fetchone()
|
||||||
|
|
||||||
global _yolo_model
|
# 更新模型后重置版本标识
|
||||||
|
global _yolo_model, _current_model_version
|
||||||
if need_load_default:
|
if need_load_default:
|
||||||
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
|
valid_abs_path = get_valid_model_abs_path(updated_model["path"])
|
||||||
_yolo_model = load_yolo_model(valid_abs_path)
|
_yolo_model = load_yolo_model(valid_abs_path)
|
||||||
if not _yolo_model:
|
if _yolo_model:
|
||||||
|
setattr(_yolo_model, "model_path", valid_abs_path)
|
||||||
|
model_stat = os.stat(valid_abs_path)
|
||||||
|
_current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}"
|
||||||
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})"
|
detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})"
|
||||||
@ -382,6 +490,96 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest):
|
|||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 5.1 更换默认模型(自动重启服务)
|
||||||
|
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)")
|
||||||
|
async def set_default_model(model_id: int):
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
conn = db.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
conn.autocommit = False # 开启事务
|
||||||
|
|
||||||
|
# 1. 校验目标模型是否存在
|
||||||
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
|
target_model = cursor.fetchone()
|
||||||
|
if not target_model:
|
||||||
|
raise HTTPException(status_code=404, detail=f"目标模型不存在!ID:{model_id}")
|
||||||
|
|
||||||
|
# 2. 检查是否已为默认模型
|
||||||
|
if target_model["is_default"]:
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message=f"模型ID:{model_id} 已是默认模型,无需更换和重启",
|
||||||
|
data=ModelResponse(**target_model)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 校验目标模型文件合法性
|
||||||
|
try:
|
||||||
|
valid_abs_path = get_valid_model_abs_path(target_model["path"])
|
||||||
|
except HTTPException as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"目标模型文件非法,无法设为默认:{e.detail}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# 4. 数据库事务:更新默认模型状态
|
||||||
|
try:
|
||||||
|
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
|
||||||
|
cursor.execute(
|
||||||
|
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
|
||||||
|
(model_id,)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
except MySQLError as e:
|
||||||
|
conn.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# 5. 验证新模型可加载性
|
||||||
|
test_model = load_yolo_model(valid_abs_path)
|
||||||
|
if not test_model:
|
||||||
|
conn.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. 重新查询更新后的模型信息
|
||||||
|
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||||
|
updated_model = cursor.fetchone()
|
||||||
|
|
||||||
|
# 7. 重置版本标识(关键:确保下次检测加载新模型)
|
||||||
|
global _current_model_version
|
||||||
|
_current_model_version = None
|
||||||
|
print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型")
|
||||||
|
|
||||||
|
# 8. 延迟重启服务
|
||||||
|
print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})")
|
||||||
|
threading.Timer(
|
||||||
|
interval=1.0,
|
||||||
|
function=restart_service
|
||||||
|
).start()
|
||||||
|
|
||||||
|
# 9. 返回成功响应
|
||||||
|
return APIResponse(
|
||||||
|
code=200,
|
||||||
|
message=f"已成功更换默认模型(ID:{model_id})!服务将在1秒后自动重启以应用新模型",
|
||||||
|
data=ModelResponse(**updated_model)
|
||||||
|
)
|
||||||
|
|
||||||
|
except MySQLError as e:
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.autocommit = True
|
||||||
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 6. 删除模型
|
# 6. 删除模型
|
||||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||||
async def delete_model(model_id: int):
|
async def delete_model(model_id: int):
|
||||||
@ -420,10 +618,12 @@ async def delete_model(model_id: int):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
extra_msg = f"(文件删除失败:{str(e)})"
|
extra_msg = f"(文件删除失败:{str(e)})"
|
||||||
|
|
||||||
global _yolo_model
|
# 如果删除的是当前加载的模型,重置缓存
|
||||||
if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str:
|
global _yolo_model, _current_model_version
|
||||||
|
if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str:
|
||||||
_yolo_model = None
|
_yolo_model = None
|
||||||
print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})")
|
_current_model_version = None
|
||||||
|
print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str})")
|
||||||
|
|
||||||
return APIResponse(
|
return APIResponse(
|
||||||
code=200,
|
code=200,
|
||||||
@ -466,32 +666,3 @@ async def download_model(model_id: int):
|
|||||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||||
finally:
|
finally:
|
||||||
db.close_connection(conn, cursor)
|
db.close_connection(conn, cursor)
|
||||||
|
|
||||||
|
|
||||||
# 对外提供当前模型
|
|
||||||
def get_current_yolo_model():
|
|
||||||
"""供检测模块获取当前加载的模型"""
|
|
||||||
global _yolo_model
|
|
||||||
if not _yolo_model:
|
|
||||||
conn = None
|
|
||||||
cursor = None
|
|
||||||
try:
|
|
||||||
conn = db.get_connection()
|
|
||||||
cursor = conn.cursor(dictionary=True)
|
|
||||||
cursor.execute("SELECT path FROM model WHERE is_default = 1")
|
|
||||||
default_model = cursor.fetchone()
|
|
||||||
if not default_model:
|
|
||||||
print("[get_current_yolo_model] 暂无默认模型")
|
|
||||||
return None
|
|
||||||
|
|
||||||
valid_abs_path = get_valid_model_abs_path(default_model["path"])
|
|
||||||
_yolo_model = load_yolo_model(valid_abs_path)
|
|
||||||
if _yolo_model:
|
|
||||||
print(f"[get_current_yolo_model] 自动加载默认模型成功")
|
|
||||||
else:
|
|
||||||
print(f"[get_current_yolo_model] 自动加载默认模型失败")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[get_current_yolo_model] 加载失败:{str(e)}")
|
|
||||||
finally:
|
|
||||||
db.close_connection(conn, cursor)
|
|
||||||
return _yolo_model
|
|
||||||
|
@ -12,7 +12,7 @@ def save_face_to_up_images(
|
|||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
|
保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径
|
||||||
修复路径计算错误,确保所有路径在up_images根目录下
|
修复路径计算错误,确保所有路径在up_images根目录下,且统一使用正斜杠
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
client_ip: 客户端IP(原始格式,如192.168.1.101)
|
client_ip: 客户端IP(原始格式,如192.168.1.101)
|
||||||
@ -38,7 +38,7 @@ def save_face_to_up_images(
|
|||||||
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
|
safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符
|
||||||
|
|
||||||
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
|
# 3. 构建根目录(强制转为绝对路径,避免相对路径混淆)
|
||||||
root_dir = Path("up_images").resolve() # 转为绝对路径(关键修复!)
|
root_dir = Path("up_images").resolve() # 转为绝对路径
|
||||||
if not root_dir.exists():
|
if not root_dir.exists():
|
||||||
root_dir.mkdir(parents=True, exist_ok=True)
|
root_dir.mkdir(parents=True, exist_ok=True)
|
||||||
print(f"[FileUtil] 已创建up_images根目录:{root_dir}")
|
print(f"[FileUtil] 已创建up_images根目录:{root_dir}")
|
||||||
@ -53,15 +53,16 @@ def save_face_to_up_images(
|
|||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
|
timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3]
|
||||||
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
|
image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}"
|
||||||
|
|
||||||
# 6. 计算路径(关键修复:确保所有路径都是绝对路径且在root_dir下)
|
# 6. 计算路径(确保所有路径都是绝对路径且在root_dir下)
|
||||||
local_abs_path = face_name_dir / image_filename # 绝对路径
|
local_abs_path = face_name_dir / image_filename # 绝对路径
|
||||||
|
|
||||||
# 验证路径是否在root_dir下(防止路径穿越攻击)
|
# 验证路径是否在root_dir下(防止路径穿越攻击)
|
||||||
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
|
if not local_abs_path.resolve().is_relative_to(root_dir.resolve()):
|
||||||
raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}")
|
raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}")
|
||||||
|
|
||||||
# 数据库存储路径:从root_dir开始的相对路径(如 up_images/192_168_110_31/小王/xxx.jpg)
|
# 数据库存储路径:从root_dir开始的相对路径,强制替换为正斜杠
|
||||||
db_path = str(root_dir.name / local_abs_path.relative_to(root_dir))
|
relative_path = local_abs_path.relative_to(root_dir)
|
||||||
|
db_path = str(relative_path).replace("\\", "/") # 关键修复:统一使用正斜杠
|
||||||
|
|
||||||
# 7. 写入图片文件
|
# 7. 写入图片文件
|
||||||
with open(local_abs_path, "wb") as f:
|
with open(local_abs_path, "wb") as f:
|
||||||
@ -72,7 +73,7 @@ def save_face_to_up_images(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"db_path": db_path, # 存数据库的相对路径(up_images开头)
|
"db_path": db_path, # 存数据库的相对路径(使用正斜杠)
|
||||||
"local_abs_path": str(local_abs_path), # 本地绝对路径
|
"local_abs_path": str(local_abs_path), # 本地绝对路径
|
||||||
"msg": "图片保存成功"
|
"msg": "图片保存成功"
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user