From 206652d6bb53dd1cf8cccfcbb1f970fd09b415f5 Mon Sep 17 00:00:00 2001 From: ninghongbin <2409766686@qq.com> Date: Fri, 12 Sep 2025 18:28:43 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=AF=E4=BB=A5=E6=88=90=E5=8A=9F=E5=8A=A8?= =?UTF-8?q?=E6=80=81=E6=9B=B4=E6=8D=A2yolo=E6=A8=A1=E5=9E=8B=E5=B9=B6?= =?UTF-8?q?=E9=87=8D=E5=90=AF=E6=9C=8D=E5=8A=A1=E7=94=9F=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/yolo.py | 48 ++++--- ds/db.py | 14 ++ schema/device_schema.py | 34 +++-- service/device_service.py | 246 ++++++++++++++++++++++++++++++----- service/model_service.py | 265 +++++++++++++++++++++++++++++++------- util/file_util.py | 15 ++- 6 files changed, 499 insertions(+), 123 deletions(-) diff --git a/core/yolo.py b/core/yolo.py index d7c299e..924cc01 100644 --- a/core/yolo.py +++ b/core/yolo.py @@ -1,43 +1,35 @@ import os import numpy as np from ultralytics import YOLO -from service.model_service import get_current_yolo_model # 从模型管理模块获取模型 - -# 全局模型变量 -_yolo_model = None +from service.model_service import get_current_yolo_model # 带版本校验的模型获取 def load_model(model_path=None): - """加载YOLO模型(优先使用模型管理模块的默认模型)""" - global _yolo_model - + """加载YOLO模型(优先使用带版本校验的默认模型)""" if model_path is None: - _yolo_model = get_current_yolo_model() - return _yolo_model is not None - + # 调用带版本校验的模型获取函数(自动判断是否需要重新加载) + return get_current_yolo_model() try: - _yolo_model = YOLO(model_path) - return True + # 加载指定路径模型(用于特殊场景) + return YOLO(model_path) except Exception as e: print(f"YOLO模型加载失败(指定路径):{str(e)}") - return False + return None -def detect(frame, conf_threshold=0.2): - """执行目标检测,返回(是否成功, 结果字符串)""" - global _yolo_model - - # 确保模型已加载 - if not _yolo_model: - if not load_model(): - return (False, "模型未初始化") +def detect(frame, conf_threshold=0.7): + """执行目标检测(仅模型版本变化时重新加载,平时复用缓存)""" + # 获取模型(内部已做版本校验,未变化则直接返回缓存) + current_model = load_model() + if not current_model: + return (False, "未加载到最新YOLO模型") if frame is None: return (False, "无效输入帧") try: - # 执行检测(frame应为numpy数组) - results = _yolo_model(frame, conf=conf_threshold, verbose=False) + # 用当前模型执行检测(复用缓存,无额外加载耗时) + results = current_model(frame, conf=conf_threshold, verbose=False) has_results = len(results[0].boxes) > 0 if results else False if not has_results: @@ -49,11 +41,17 @@ def detect(frame, conf_threshold=0.2): cls = int(box.cls[0]) conf = float(box.conf[0]) bbox = [round(x, 2) for x in box.xyxy[0].tolist()] # 保留两位小数 - class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}" + # 从当前模型中获取类别名(确保与模型匹配) + class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}" result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})") + # 打印当前使用的模型路径和版本(用于验证) + # model_path = getattr(current_model, "model_path", "未知路径") + # from service.model_service import _current_model_version + # print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...)") + return (True, "; ".join(result_parts)) except Exception as e: - print(f"检测过程出错:{str(e)}") + print(f"YOLO检测过程出错:{str(e)}") return (False, f"检测错误:{str(e)}") diff --git a/ds/db.py b/ds/db.py index ae891b7..f06c09c 100644 --- a/ds/db.py +++ b/ds/db.py @@ -3,6 +3,8 @@ from mysql.connector import Error from .config import MYSQL_CONFIG +# 关键:声明类级别的连接池实例(必须有这一行!) +_connection_pool = None # 确保这一行存在,且拼写正确 class Database: """MySQL 连接池管理类""" @@ -41,6 +43,18 @@ class Database: except Error as e: raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e + @classmethod + def close_all_connections(cls): + """清理连接池(服务重启前调用)""" + try: + # 先检查属性是否存在,再判断是否有值 + if hasattr(cls, "_connection_pool") and cls._connection_pool: + cls._connection_pool = None # 重置连接池 + print("[Database] 连接池已重置,旧连接将被自动清理") + else: + print("[Database] 连接池未初始化或已重置,无需操作") + except Exception as e: + print(f"[Database] 重置连接池失败: {str(e)}") # 暴露数据库操作工具 db = Database() diff --git a/schema/device_schema.py b/schema/device_schema.py index adf7e53..f3af2ab 100644 --- a/schema/device_schema.py +++ b/schema/device_schema.py @@ -8,32 +8,48 @@ from pydantic import BaseModel, Field # 请求模型 # ------------------------------ class DeviceCreateRequest(BaseModel): - """设备流信息创建请求模型(与数据库表字段对齐)""" + """设备创建请求模型""" ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") hostname: Optional[str] = Field(None, max_length=100, description="设备别名") - params: Optional[Dict] = Field(None, description="设备详细信息(JSON格式)") + params: Optional[Dict] = Field(None, description="设备扩展参数(JSON格式)") # ------------------------------ -# 响应模型(后端返回数据)- 严格对齐数据库表字段 +# 响应模型 # ------------------------------ class DeviceResponse(BaseModel): - """设备流信息响应模型(与数据库表字段完全一致)""" + """单设备信息响应模型(与数据库表字段对齐)""" id: int = Field(..., description="设备主键ID") client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址") hostname: Optional[str] = Field(None, max_length=100, description="设备别名") - device_online_status: int = Field(..., description="设备在线状态(1-在线、0-离线)") + device_online_status: int = Field(..., description="在线状态(1-在线、0-离线)") device_type: Optional[str] = Field(None, description="设备类型") alarm_count: int = Field(..., description="报警次数") - params: Optional[str] = Field(None, description="设备详细信息(JSON字符串)") + params: Optional[str] = Field(None, description="扩展参数(JSON字符串)") created_at: datetime = Field(..., description="记录创建时间") updated_at: datetime = Field(..., description="记录更新时间") - # 支持从数据库查询结果直接转换 - model_config = {"from_attributes": True} + model_config = {"from_attributes": True} # 支持从数据库结果直接转换 class DeviceListResponse(BaseModel): - """设备流信息列表响应模型""" + """设备列表响应模型""" total: int = Field(..., description="设备总数") devices: List[DeviceResponse] = Field(..., description="设备列表") + + +class DeviceStatusHistoryResponse(BaseModel): + """设备上下线记录响应模型""" + id: int = Field(..., description="记录ID") + device_id: int = Field(..., description="关联设备ID") + client_ip: Optional[str] = Field(None, description="设备IP地址") + status: int = Field(..., description="状态(1-在线、0-离线)") + status_time: datetime = Field(..., description="状态变更时间") + + model_config = {"from_attributes": True} + + +class DeviceStatusHistoryListResponse(BaseModel): + """设备上下线记录列表响应模型""" + total: int = Field(..., description="记录总数") + history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表") \ No newline at end of file diff --git a/service/device_service.py b/service/device_service.py index a5f15fb..248a2ba 100644 --- a/service/device_service.py +++ b/service/device_service.py @@ -1,10 +1,14 @@ import json +from datetime import datetime, date -from fastapi import APIRouter, Query, HTTPException,Request +from fastapi import APIRouter, Query, HTTPException, Request, Path from mysql.connector import Error as MySQLError from ds.db import db -from schema.device_schema import DeviceCreateRequest, DeviceResponse, DeviceListResponse +from schema.device_schema import ( + DeviceCreateRequest, DeviceResponse, DeviceListResponse, + DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse +) from schema.response_schema import APIResponse router = APIRouter( @@ -13,12 +17,53 @@ router = APIRouter( ) +# ------------------------------ +# 内部工具方法 - 记录设备状态变更历史 +# ------------------------------ +def record_status_change(client_ip: str, status: int) -> bool: + """ + 记录设备状态变更历史(写入 device_action 表) + + :param client_ip: 设备IP + :param status: 状态(1-在线、0-离线) + :return: 操作是否成功 + """ + if not client_ip: + raise ValueError("客户端IP不能为空") + + if status not in (0, 1): + raise ValueError("状态必须是0(离线)或1(在线)") + + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 插入状态变更记录到 device_action + insert_query = """ + INSERT INTO device_action + (client_ip, action, created_at, updated_at) + VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + """ + cursor.execute(insert_query, (client_ip, status)) + conn.commit() + + return True + except MySQLError as e: + if conn: + conn.rollback() + raise Exception(f"记录设备状态变更失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) + + # ------------------------------ # 内部工具方法 - 通过客户端IP增加设备报警次数 # ------------------------------ def increment_alarm_count_by_ip(client_ip: str) -> bool: """ - 通过客户端IP增加设备的报警次数(内部服务方法) + 通过客户端IP增加设备的报警次数 :param client_ip: 客户端IP地址 :return: 操作是否成功 @@ -34,7 +79,8 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool: # 检查设备是否存在 cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) - if not cursor.fetchone(): + device = cursor.fetchone() + if not device: raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") # 报警次数加1、并更新时间戳 @@ -61,7 +107,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool: # ------------------------------ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: """ - 通过客户端IP更新设备的在线状态(内部服务方法) + 通过客户端IP更新设备的在线状态 :param client_ip: 客户端IP地址 :param online_status: 在线状态(1-在线、0-离线) @@ -70,7 +116,6 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: if not client_ip: raise ValueError("客户端IP不能为空") - # 验证状态值有效性 if online_status not in (0, 1): raise ValueError("在线状态必须是0(离线)或1(在线)") @@ -80,11 +125,16 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: conn = db.get_connection() cursor = conn.cursor(dictionary=True) - # 检查设备是否存在 - cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) - if not cursor.fetchone(): + # 检查设备是否存在并获取设备ID + cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,)) + device = cursor.fetchone() + if not device: raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") + # 状态无变化则不操作 + if device['device_online_status'] == online_status: + return True + # 更新在线状态和时间戳 update_query = """ UPDATE devices @@ -93,8 +143,11 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: WHERE client_ip = %s """ cursor.execute(update_query, (online_status, client_ip)) - conn.commit() + # 记录状态变更历史 + record_status_change(client_ip, online_status) + + conn.commit() return True except MySQLError as e: if conn: @@ -105,46 +158,43 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: # ------------------------------ -# 原有接口保持不变 +# 创建设备信息接口 # ------------------------------ @router.post("/add", response_model=APIResponse, summary="创建设备信息") -async def create_device(device_data: DeviceCreateRequest, request: Request): # 注入Request对象 - # 原有代码保持不变 +async def create_device(device_data: DeviceCreateRequest, request: Request): conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) + # 检查设备是否已存在 cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,)) existing_device = cursor.fetchone() if existing_device: - # 更新设备状态为在线 + # 更新设备为在线状态 update_online_status_by_ip(client_ip=device_data.ip, online_status=1) - # 返回信息 return APIResponse( code=200, message=f"设备IP {device_data.ip} 已存在、返回已有设备信息", - data=DeviceResponse(** existing_device) + data=DeviceResponse(**existing_device) ) - # 直接使用注入的request对象获取用户代理 + # 通过 User-Agent 判断设备类型 user_agent = request.headers.get("User-Agent", "").lower() - + device_type = "unknown" if user_agent == "default": - device_type = device_data.params.get("os") if ( - device_data.params and isinstance(device_data.params, dict)) else "unknown" + device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown" elif "windows" in user_agent: device_type = "windows" elif "android" in user_agent: device_type = "android" elif "linux" in user_agent: device_type = "linux" - else: - device_type = "unknown" device_params_json = json.dumps(device_data.params) if device_data.params else None + # 插入新设备 insert_query = """ INSERT INTO devices (client_ip, hostname, device_online_status, device_type, alarm_count, params) @@ -160,10 +210,14 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): # )) conn.commit() + # 获取新设备并返回 device_id = cursor.lastrowid cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) new_device = cursor.fetchone() + # 记录上线历史 + record_status_change(device_data.ip, 1) + return APIResponse( code=200, message="设备创建成功", @@ -175,7 +229,7 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): # conn.rollback() raise Exception(f"创建设备失败: {str(e)}") from e except json.JSONDecodeError as e: - raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e + raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e except Exception as e: if conn: conn.rollback() @@ -183,14 +237,17 @@ async def create_device(device_data: DeviceCreateRequest, request: Request): # finally: db.close_connection(conn, cursor) + +# ------------------------------ +# 获取设备列表接口 +# ------------------------------ @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") async def get_device_list( - page: int = Query(1, ge=1, description="页码、默认第1页"), - page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"), + page: int = Query(1, ge=1, description="页码,默认第1页"), + page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), device_type: str = Query(None, description="按设备类型筛选"), online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选") ): - # 原有代码保持不变 conn = None cursor = None try: @@ -207,12 +264,14 @@ async def get_device_list( where_clause.append("device_online_status = %s") params.append(online_status) + # 统计总数 count_query = "SELECT COUNT(*) AS total FROM devices" if where_clause: count_query += " WHERE " + " AND ".join(where_clause) cursor.execute(count_query, params) total = cursor.fetchone()["total"] + # 分页查询列表 offset = (page - 1) * page_size list_query = "SELECT * FROM devices" if where_clause: @@ -238,23 +297,140 @@ async def get_device_list( db.close_connection(conn, cursor) -def get_unique_client_ips() -> list[str]: - """ - 获取所有去重的客户端IP列表 - - :return: 去重后的客户端IP字符串列表,如果没有数据则返回空列表 - """ +# ------------------------------ +# 获取设备上下线记录接口 +# ------------------------------ +@router.get("/{device_id}/status-history", response_model=APIResponse, summary="获取设备上下线记录") +async def get_device_status_history( + device_id: int = Path(..., description="设备ID"), + page: int = Query(1, ge=1, description="页码,默认第1页"), + page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), + start_date: date = Query(None, description="开始日期,格式YYYY-MM-DD"), + end_date: date = Query(None, description="结束日期,格式YYYY-MM-DD") +): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 检查设备是否存在并获取 client_ip + cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,)) + device = cursor.fetchone() + if not device: + raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在") + client_ip = device['client_ip'] + + where_clause = ["client_ip = %s"] + params = [client_ip] + + # 日期筛选 + if start_date: + where_clause.append("DATE(created_at) >= %s") + params.append(start_date.strftime("%Y-%m-%d")) + if end_date: + where_clause.append("DATE(created_at) <= %s") + params.append(end_date.strftime("%Y-%m-%d")) + + # 统计记录总数 + count_query = "SELECT COUNT(*) AS total FROM device_action WHERE " + " AND ".join(where_clause) + cursor.execute(count_query, params) + total = cursor.fetchone()["total"] + + # 分页查询记录 + offset = (page - 1) * page_size + list_query = f""" + SELECT * FROM device_action + WHERE {' AND '.join(where_clause)} + ORDER BY created_at DESC + LIMIT %s OFFSET %s + """ + params.extend([page_size, offset]) + cursor.execute(list_query, params) + history_list = cursor.fetchall() + + # 格式化为响应模型结构 + formatted_history = [] + for item in history_list: + formatted_item = { + "id": item["id"], + "device_id": device_id, + "client_ip": item["client_ip"], + "status": item["action"], + "status_time": item["created_at"] + } + formatted_history.append(formatted_item) + + return APIResponse( + code=200, + message="获取设备上下线记录成功", + data=DeviceStatusHistoryListResponse( + total=total, + history=[DeviceStatusHistoryResponse(**item) for item in formatted_history] + ) + ) + + except MySQLError as e: + raise Exception(f"获取设备上下线记录失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 手动更新设备在线状态接口 +# ------------------------------ +@router.put("/{device_id}/status", response_model=APIResponse, summary="更新设备在线状态") +async def update_device_status( + device_id: int = Path(..., description="设备ID"), + status: int = Query(..., ge=0, le=1, description="在线状态(1-在线、0-离线)") +): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + + # 获取设备 client_ip + cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,)) + device = cursor.fetchone() + if not device: + raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在") + + # 更新状态 + success = update_online_status_by_ip(device['client_ip'], status) + + if success: + status_text = "在线" if status == 1 else "离线" + return APIResponse( + code=200, + message=f"设备已更新为{status_text}状态", + data={"device_id": device_id, "status": status, "status_text": status_text} + ) + return APIResponse( + code=500, + message="更新设备状态失败", + data=None + ) + + except MySQLError as e: + raise Exception(f"更新设备状态失败: {str(e)}") from e + finally: + db.close_connection(conn, cursor) + + +# ------------------------------ +# 获取所有去重的客户端IP列表 +# ------------------------------ +def get_unique_client_ips() -> list[str]: + """获取所有去重的客户端IP列表""" conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) - # 查询去重的客户端IP query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL" cursor.execute(query) - - # 提取结果并转换为字符串列表 results = cursor.fetchall() return [item['client_ip'] for item in results] diff --git a/service/model_service.py b/service/model_service.py index 8921c4a..599920e 100644 --- a/service/model_service.py +++ b/service/model_service.py @@ -1,10 +1,13 @@ +import subprocess +import os +import sys +import shutil +import threading +from pathlib import Path +from datetime import datetime from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query from fastapi.responses import FileResponse from mysql.connector import Error as MySQLError -import os -import shutil -from pathlib import Path -from datetime import datetime # 复用项目依赖 from ds.db import db @@ -15,7 +18,7 @@ from schema.model_schema import ( ModelListResponse ) from schema.response_schema import APIResponse -from util.model_util import load_yolo_model # 使用修复后的模型加载工具 +from util.model_util import load_yolo_model # 模型加载工具 # 路径配置 CURRENT_FILE_PATH = Path(__file__).resolve() @@ -28,14 +31,63 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep ALLOWED_MODEL_EXT = {"pt"} MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB -# 全局模型变量 -global _yolo_model +# 全局模型变量(带版本标识) +global _yolo_model, _current_model_version _yolo_model = None +_current_model_version = None # 模型版本标识(用于检测模型是否变化) router = APIRouter(prefix="/models", tags=["模型管理"]) -# 工具函数:验证模型路径 +# 服务重启核心工具函数 +def restart_service(): + """重启当前FastAPI服务进程""" + print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...") + try: + # 关闭所有WebSocket连接 + try: + from ws import connected_clients + if connected_clients: + print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接") + for ip, conn in list(connected_clients.items()): + try: + if conn.consumer_task and not conn.consumer_task.done(): + conn.consumer_task.cancel() + conn.websocket.close(code=1001, reason="模型更新,服务重启") + connected_clients.pop(ip) + except Exception as e: + print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}") + except ImportError: + print("[服务重启] 未找到WebSocket连接管理模块,跳过连接关闭") + + # 关闭数据库连接 + if hasattr(db, "close_all_connections"): + db.close_all_connections() + else: + print("[警告] db模块未实现close_all_connections,可能存在连接泄漏") + + # 启动新进程 + python_exec = sys.executable + current_argv = sys.argv + print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}") + subprocess.Popen( + [python_exec] + current_argv, + close_fds=True, + start_new_session=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + + # 退出当前进程 + print("[服务重启] 新进程已启动,当前进程退出") + sys.exit(0) + + except Exception as e: + print(f"[服务重启] 重启失败:{str(e)}") + raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e + + +# 模型路径验证工具函数 def get_valid_model_abs_path(relative_path: str) -> str: try: relative_path = relative_path.replace("/", os.sep) @@ -87,6 +139,49 @@ def get_valid_model_abs_path(relative_path: str) -> str: ) from e +# 对外提供当前模型(带版本校验) +def get_current_yolo_model(): + """供检测模块获取当前最新默认模型(仅版本变化时重新加载)""" + global _yolo_model, _current_model_version + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute("SELECT path FROM model WHERE is_default = 1") + default_model = cursor.fetchone() + if not default_model: + print("[get_current_yolo_model] 暂无默认模型") + return None + + # 1. 计算当前默认模型的唯一版本标识 + # (路径哈希 + 文件修改时间戳,确保模型变化时版本变化) + valid_abs_path = get_valid_model_abs_path(default_model["path"]) + model_stat = os.stat(valid_abs_path) + model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" + + # 2. 版本未变化则复用已有模型(核心优化点) + if _yolo_model and _current_model_version == model_version: + # print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...)") + return _yolo_model + + # 3. 版本变化时重新加载模型 + _yolo_model = load_yolo_model(valid_abs_path) + if _yolo_model: + setattr(_yolo_model, "model_path", valid_abs_path) + _current_model_version = model_version # 更新版本标识 + print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)") + else: + print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}") + return _yolo_model + + except Exception as e: + print(f"[get_current_yolo_model] 加载失败:{str(e)}") + return None + finally: + db.close_connection(conn, cursor) + + # 1. 上传模型 @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") async def upload_model( @@ -142,12 +237,16 @@ async def upload_model( if not new_model: raise HTTPException(status_code=500, detail="上传成功但无法获取记录") - # 加载默认模型 - global _yolo_model + # 加载默认模型并更新版本 + global _yolo_model, _current_model_version if is_default: valid_abs_path = get_valid_model_abs_path(db_relative_path) _yolo_model = load_yolo_model(valid_abs_path) - if not _yolo_model: + if _yolo_model: + setattr(_yolo_model, "model_path", valid_abs_path) + model_stat = os.stat(valid_abs_path) + _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" + else: raise HTTPException( status_code=500, detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})" @@ -246,11 +345,15 @@ async def get_default_model(): raise HTTPException(status_code=404, detail="暂无默认模型") valid_abs_path = get_valid_model_abs_path(default_model["path"]) - global _yolo_model + global _yolo_model, _current_model_version if not _yolo_model: _yolo_model = load_yolo_model(valid_abs_path) - if not _yolo_model: + if _yolo_model: + setattr(_yolo_model, "model_path", valid_abs_path) + model_stat = os.stat(valid_abs_path) + _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" + else: raise HTTPException( status_code=500, detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})" @@ -358,11 +461,16 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest): cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) updated_model = cursor.fetchone() - global _yolo_model + # 更新模型后重置版本标识 + global _yolo_model, _current_model_version if need_load_default: valid_abs_path = get_valid_model_abs_path(updated_model["path"]) _yolo_model = load_yolo_model(valid_abs_path) - if not _yolo_model: + if _yolo_model: + setattr(_yolo_model, "model_path", valid_abs_path) + model_stat = os.stat(valid_abs_path) + _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" + else: raise HTTPException( status_code=500, detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})" @@ -382,6 +490,96 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest): db.close_connection(conn, cursor) +# 5.1 更换默认模型(自动重启服务) +@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)") +async def set_default_model(model_id: int): + conn = None + cursor = None + try: + conn = db.get_connection() + cursor = conn.cursor(dictionary=True) + conn.autocommit = False # 开启事务 + + # 1. 校验目标模型是否存在 + cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) + target_model = cursor.fetchone() + if not target_model: + raise HTTPException(status_code=404, detail=f"目标模型不存在!ID:{model_id}") + + # 2. 检查是否已为默认模型 + if target_model["is_default"]: + return APIResponse( + code=200, + message=f"模型ID:{model_id} 已是默认模型,无需更换和重启", + data=ModelResponse(**target_model) + ) + + # 3. 校验目标模型文件合法性 + try: + valid_abs_path = get_valid_model_abs_path(target_model["path"]) + except HTTPException as e: + raise HTTPException( + status_code=400, + detail=f"目标模型文件非法,无法设为默认:{e.detail}" + ) from e + + # 4. 数据库事务:更新默认模型状态 + try: + cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP") + cursor.execute( + "UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s", + (model_id,) + ) + conn.commit() + except MySQLError as e: + conn.rollback() + raise HTTPException( + status_code=500, + detail=f"更新默认模型状态失败(已回滚):{str(e)}" + ) from e + + # 5. 验证新模型可加载性 + test_model = load_yolo_model(valid_abs_path) + if not test_model: + conn.rollback() + raise HTTPException( + status_code=500, + detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path})" + ) + + # 6. 重新查询更新后的模型信息 + cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) + updated_model = cursor.fetchone() + + # 7. 重置版本标识(关键:确保下次检测加载新模型) + global _current_model_version + _current_model_version = None + print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型") + + # 8. 延迟重启服务 + print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})") + threading.Timer( + interval=1.0, + function=restart_service + ).start() + + # 9. 返回成功响应 + return APIResponse( + code=200, + message=f"已成功更换默认模型(ID:{model_id})!服务将在1秒后自动重启以应用新模型", + data=ModelResponse(**updated_model) + ) + + except MySQLError as e: + if conn: + conn.rollback() + raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e + finally: + if conn: + conn.autocommit = True + db.close_connection(conn, cursor) + + # 6. 删除模型 @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") async def delete_model(model_id: int): @@ -420,10 +618,12 @@ async def delete_model(model_id: int): except Exception as e: extra_msg = f"(文件删除失败:{str(e)})" - global _yolo_model - if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str: + # 如果删除的是当前加载的模型,重置缓存 + global _yolo_model, _current_model_version + if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str: _yolo_model = None - print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})") + _current_model_version = None + print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str})") return APIResponse( code=200, @@ -466,32 +666,3 @@ async def download_model(model_id: int): raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e finally: db.close_connection(conn, cursor) - - -# 对外提供当前模型 -def get_current_yolo_model(): - """供检测模块获取当前加载的模型""" - global _yolo_model - if not _yolo_model: - conn = None - cursor = None - try: - conn = db.get_connection() - cursor = conn.cursor(dictionary=True) - cursor.execute("SELECT path FROM model WHERE is_default = 1") - default_model = cursor.fetchone() - if not default_model: - print("[get_current_yolo_model] 暂无默认模型") - return None - - valid_abs_path = get_valid_model_abs_path(default_model["path"]) - _yolo_model = load_yolo_model(valid_abs_path) - if _yolo_model: - print(f"[get_current_yolo_model] 自动加载默认模型成功") - else: - print(f"[get_current_yolo_model] 自动加载默认模型失败") - except Exception as e: - print(f"[get_current_yolo_model] 加载失败:{str(e)}") - finally: - db.close_connection(conn, cursor) - return _yolo_model diff --git a/util/file_util.py b/util/file_util.py index 8e60e9d..2305ed2 100644 --- a/util/file_util.py +++ b/util/file_util.py @@ -12,7 +12,7 @@ def save_face_to_up_images( ) -> Dict[str, str]: """ 保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 - 修复路径计算错误,确保所有路径在up_images根目录下 + 修复路径计算错误,确保所有路径在up_images根目录下,且统一使用正斜杠 参数: client_ip: 客户端IP(原始格式,如192.168.1.101) @@ -38,7 +38,7 @@ def save_face_to_up_images( safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|']) # 过滤非法字符 # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) - root_dir = Path("up_images").resolve() # 转为绝对路径(关键修复!) + root_dir = Path("up_images").resolve() # 转为绝对路径 if not root_dir.exists(): root_dir.mkdir(parents=True, exist_ok=True) print(f"[FileUtil] 已创建up_images根目录:{root_dir}") @@ -53,15 +53,16 @@ def save_face_to_up_images( timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" - # 6. 计算路径(关键修复:确保所有路径都是绝对路径且在root_dir下) + # 6. 计算路径(确保所有路径都是绝对路径且在root_dir下) local_abs_path = face_name_dir / image_filename # 绝对路径 # 验证路径是否在root_dir下(防止路径穿越攻击) if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") - # 数据库存储路径:从root_dir开始的相对路径(如 up_images/192_168_110_31/小王/xxx.jpg) - db_path = str(root_dir.name / local_abs_path.relative_to(root_dir)) + # 数据库存储路径:从root_dir开始的相对路径,强制替换为正斜杠 + relative_path = local_abs_path.relative_to(root_dir) + db_path = str(relative_path).replace("\\", "/") # 关键修复:统一使用正斜杠 # 7. 写入图片文件 with open(local_abs_path, "wb") as f: @@ -72,7 +73,7 @@ def save_face_to_up_images( return { "success": True, - "db_path": db_path, # 存数据库的相对路径(up_images开头) + "db_path": db_path, # 存数据库的相对路径(使用正斜杠) "local_abs_path": str(local_abs_path), # 本地绝对路径 "msg": "图片保存成功" } @@ -80,4 +81,4 @@ def save_face_to_up_images( except Exception as e: error_msg = f"图片保存失败:{str(e)}" print(f"[FileUtil] 错误:{error_msg}") - return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg} \ No newline at end of file + return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg}