可以成功动态更换yolo模型并重启服务生效
This commit is contained in:
		
							
								
								
									
										48
									
								
								core/yolo.py
									
									
									
									
									
								
							
							
						
						
									
										48
									
								
								core/yolo.py
									
									
									
									
									
								
							| @ -1,43 +1,35 @@ | ||||
| import os | ||||
| import numpy as np | ||||
| from ultralytics import YOLO | ||||
| from service.model_service import get_current_yolo_model  # 从模型管理模块获取模型 | ||||
|  | ||||
| # 全局模型变量 | ||||
| _yolo_model = None | ||||
| from service.model_service import get_current_yolo_model  # 带版本校验的模型获取 | ||||
|  | ||||
|  | ||||
| def load_model(model_path=None): | ||||
|     """加载YOLO模型(优先使用模型管理模块的默认模型)""" | ||||
|     global _yolo_model | ||||
|  | ||||
|     """加载YOLO模型(优先使用带版本校验的默认模型)""" | ||||
|     if model_path is None: | ||||
|         _yolo_model = get_current_yolo_model() | ||||
|         return _yolo_model is not None | ||||
|  | ||||
|         # 调用带版本校验的模型获取函数(自动判断是否需要重新加载) | ||||
|         return get_current_yolo_model() | ||||
|     try: | ||||
|         _yolo_model = YOLO(model_path) | ||||
|         return True | ||||
|         # 加载指定路径模型(用于特殊场景) | ||||
|         return YOLO(model_path) | ||||
|     except Exception as e: | ||||
|         print(f"YOLO模型加载失败(指定路径):{str(e)}") | ||||
|         return False | ||||
|         return None | ||||
|  | ||||
|  | ||||
| def detect(frame, conf_threshold=0.2): | ||||
|     """执行目标检测,返回(是否成功, 结果字符串)""" | ||||
|     global _yolo_model | ||||
|  | ||||
|     # 确保模型已加载 | ||||
|     if not _yolo_model: | ||||
|         if not load_model(): | ||||
|             return (False, "模型未初始化") | ||||
| def detect(frame, conf_threshold=0.7): | ||||
|     """执行目标检测(仅模型版本变化时重新加载,平时复用缓存)""" | ||||
|     # 获取模型(内部已做版本校验,未变化则直接返回缓存) | ||||
|     current_model = load_model() | ||||
|     if not current_model: | ||||
|         return (False, "未加载到最新YOLO模型") | ||||
|  | ||||
|     if frame is None: | ||||
|         return (False, "无效输入帧") | ||||
|  | ||||
|     try: | ||||
|         # 执行检测(frame应为numpy数组) | ||||
|         results = _yolo_model(frame, conf=conf_threshold, verbose=False) | ||||
|         # 用当前模型执行检测(复用缓存,无额外加载耗时) | ||||
|         results = current_model(frame, conf=conf_threshold, verbose=False) | ||||
|         has_results = len(results[0].boxes) > 0 if results else False | ||||
|  | ||||
|         if not has_results: | ||||
| @ -49,11 +41,17 @@ def detect(frame, conf_threshold=0.2): | ||||
|             cls = int(box.cls[0]) | ||||
|             conf = float(box.conf[0]) | ||||
|             bbox = [round(x, 2) for x in box.xyxy[0].tolist()]  # 保留两位小数 | ||||
|             class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}" | ||||
|             # 从当前模型中获取类别名(确保与模型匹配) | ||||
|             class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}" | ||||
|             result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})") | ||||
|  | ||||
|         # 打印当前使用的模型路径和版本(用于验证) | ||||
|         # model_path = getattr(current_model, "model_path", "未知路径") | ||||
|         # from service.model_service import _current_model_version | ||||
|         # print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...)") | ||||
|  | ||||
|         return (True, "; ".join(result_parts)) | ||||
|  | ||||
|     except Exception as e: | ||||
|         print(f"检测过程出错:{str(e)}") | ||||
|         print(f"YOLO检测过程出错:{str(e)}") | ||||
|         return (False, f"检测错误:{str(e)}") | ||||
|  | ||||
							
								
								
									
										14
									
								
								ds/db.py
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								ds/db.py
									
									
									
									
									
								
							| @ -3,6 +3,8 @@ from mysql.connector import Error | ||||
|  | ||||
| from .config import MYSQL_CONFIG | ||||
|  | ||||
| # 关键:声明类级别的连接池实例(必须有这一行!) | ||||
| _connection_pool = None  # 确保这一行存在,且拼写正确 | ||||
|  | ||||
| class Database: | ||||
|     """MySQL 连接池管理类""" | ||||
| @ -41,6 +43,18 @@ class Database: | ||||
|         except Error as e: | ||||
|             raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e | ||||
|  | ||||
|     @classmethod | ||||
|     def close_all_connections(cls): | ||||
|         """清理连接池(服务重启前调用)""" | ||||
|         try: | ||||
|             # 先检查属性是否存在,再判断是否有值 | ||||
|             if hasattr(cls, "_connection_pool") and cls._connection_pool: | ||||
|                 cls._connection_pool = None  # 重置连接池 | ||||
|                 print("[Database] 连接池已重置,旧连接将被自动清理") | ||||
|             else: | ||||
|                 print("[Database] 连接池未初始化或已重置,无需操作") | ||||
|         except Exception as e: | ||||
|             print(f"[Database] 重置连接池失败: {str(e)}") | ||||
|  | ||||
| # 暴露数据库操作工具 | ||||
| db = Database() | ||||
|  | ||||
| @ -8,32 +8,48 @@ from pydantic import BaseModel, Field | ||||
| # 请求模型 | ||||
| # ------------------------------ | ||||
| class DeviceCreateRequest(BaseModel): | ||||
|     """设备流信息创建请求模型(与数据库表字段对齐)""" | ||||
|     """设备创建请求模型""" | ||||
|     ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") | ||||
|     hostname: Optional[str] = Field(None, max_length=100, description="设备别名") | ||||
|     params: Optional[Dict] = Field(None, description="设备详细信息(JSON格式)") | ||||
|     params: Optional[Dict] = Field(None, description="设备扩展参数(JSON格式)") | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 响应模型(后端返回数据)- 严格对齐数据库表字段 | ||||
| # 响应模型 | ||||
| # ------------------------------ | ||||
| class DeviceResponse(BaseModel): | ||||
|     """设备流信息响应模型(与数据库表字段完全一致)""" | ||||
|     """单设备信息响应模型(与数据库表字段对齐)""" | ||||
|     id: int = Field(..., description="设备主键ID") | ||||
|     client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址") | ||||
|     hostname: Optional[str] = Field(None, max_length=100, description="设备别名") | ||||
|     device_online_status: int = Field(..., description="设备在线状态(1-在线、0-离线)") | ||||
|     device_online_status: int = Field(..., description="在线状态(1-在线、0-离线)") | ||||
|     device_type: Optional[str] = Field(None, description="设备类型") | ||||
|     alarm_count: int = Field(..., description="报警次数") | ||||
|     params: Optional[str] = Field(None, description="设备详细信息(JSON字符串)") | ||||
|     params: Optional[str] = Field(None, description="扩展参数(JSON字符串)") | ||||
|     created_at: datetime = Field(..., description="记录创建时间") | ||||
|     updated_at: datetime = Field(..., description="记录更新时间") | ||||
|  | ||||
|     # 支持从数据库查询结果直接转换 | ||||
|     model_config = {"from_attributes": True} | ||||
|     model_config = {"from_attributes": True}  # 支持从数据库结果直接转换 | ||||
|  | ||||
|  | ||||
| class DeviceListResponse(BaseModel): | ||||
|     """设备流信息列表响应模型""" | ||||
|     """设备列表响应模型""" | ||||
|     total: int = Field(..., description="设备总数") | ||||
|     devices: List[DeviceResponse] = Field(..., description="设备列表") | ||||
|  | ||||
|  | ||||
| class DeviceStatusHistoryResponse(BaseModel): | ||||
|     """设备上下线记录响应模型""" | ||||
|     id: int = Field(..., description="记录ID") | ||||
|     device_id: int = Field(..., description="关联设备ID") | ||||
|     client_ip: Optional[str] = Field(None, description="设备IP地址") | ||||
|     status: int = Field(..., description="状态(1-在线、0-离线)") | ||||
|     status_time: datetime = Field(..., description="状态变更时间") | ||||
|  | ||||
|     model_config = {"from_attributes": True} | ||||
|  | ||||
|  | ||||
| class DeviceStatusHistoryListResponse(BaseModel): | ||||
|     """设备上下线记录列表响应模型""" | ||||
|     total: int = Field(..., description="记录总数") | ||||
|     history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表") | ||||
| @ -1,10 +1,14 @@ | ||||
| import json | ||||
| from datetime import datetime, date | ||||
|  | ||||
| from fastapi import APIRouter, Query, HTTPException,Request | ||||
| from fastapi import APIRouter, Query, HTTPException, Request, Path | ||||
| from mysql.connector import Error as MySQLError | ||||
|  | ||||
| from ds.db import db | ||||
| from schema.device_schema import DeviceCreateRequest, DeviceResponse, DeviceListResponse | ||||
| from schema.device_schema import ( | ||||
|     DeviceCreateRequest, DeviceResponse, DeviceListResponse, | ||||
|     DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse | ||||
| ) | ||||
| from schema.response_schema import APIResponse | ||||
|  | ||||
| router = APIRouter( | ||||
| @ -13,12 +17,53 @@ router = APIRouter( | ||||
| ) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 内部工具方法 - 记录设备状态变更历史 | ||||
| # ------------------------------ | ||||
| def record_status_change(client_ip: str, status: int) -> bool: | ||||
|     """ | ||||
|     记录设备状态变更历史(写入 device_action 表) | ||||
|  | ||||
|     :param client_ip: 设备IP | ||||
|     :param status: 状态(1-在线、0-离线) | ||||
|     :return: 操作是否成功 | ||||
|     """ | ||||
|     if not client_ip: | ||||
|         raise ValueError("客户端IP不能为空") | ||||
|  | ||||
|     if status not in (0, 1): | ||||
|         raise ValueError("状态必须是0(离线)或1(在线)") | ||||
|  | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 插入状态变更记录到 device_action | ||||
|         insert_query = """ | ||||
|             INSERT INTO device_action  | ||||
|             (client_ip, action, created_at, updated_at) | ||||
|             VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) | ||||
|         """ | ||||
|         cursor.execute(insert_query, (client_ip, status)) | ||||
|         conn.commit() | ||||
|  | ||||
|         return True | ||||
|     except MySQLError as e: | ||||
|         if conn: | ||||
|             conn.rollback() | ||||
|         raise Exception(f"记录设备状态变更失败: {str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 内部工具方法 - 通过客户端IP增加设备报警次数 | ||||
| # ------------------------------ | ||||
| def increment_alarm_count_by_ip(client_ip: str) -> bool: | ||||
|     """ | ||||
|     通过客户端IP增加设备的报警次数(内部服务方法) | ||||
|     通过客户端IP增加设备的报警次数 | ||||
|  | ||||
|     :param client_ip: 客户端IP地址 | ||||
|     :return: 操作是否成功 | ||||
| @ -34,7 +79,8 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool: | ||||
|  | ||||
|         # 检查设备是否存在 | ||||
|         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) | ||||
|         if not cursor.fetchone(): | ||||
|         device = cursor.fetchone() | ||||
|         if not device: | ||||
|             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") | ||||
|  | ||||
|         # 报警次数加1、并更新时间戳 | ||||
| @ -61,7 +107,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool: | ||||
| # ------------------------------ | ||||
| def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | ||||
|     """ | ||||
|     通过客户端IP更新设备的在线状态(内部服务方法) | ||||
|     通过客户端IP更新设备的在线状态 | ||||
|  | ||||
|     :param client_ip: 客户端IP地址 | ||||
|     :param online_status: 在线状态(1-在线、0-离线) | ||||
| @ -70,7 +116,6 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | ||||
|     if not client_ip: | ||||
|         raise ValueError("客户端IP不能为空") | ||||
|  | ||||
|     # 验证状态值有效性 | ||||
|     if online_status not in (0, 1): | ||||
|         raise ValueError("在线状态必须是0(离线)或1(在线)") | ||||
|  | ||||
| @ -80,11 +125,16 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 检查设备是否存在 | ||||
|         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) | ||||
|         if not cursor.fetchone(): | ||||
|         # 检查设备是否存在并获取设备ID | ||||
|         cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,)) | ||||
|         device = cursor.fetchone() | ||||
|         if not device: | ||||
|             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") | ||||
|  | ||||
|         # 状态无变化则不操作 | ||||
|         if device['device_online_status'] == online_status: | ||||
|             return True | ||||
|  | ||||
|         # 更新在线状态和时间戳 | ||||
|         update_query = """ | ||||
|             UPDATE devices  | ||||
| @ -93,8 +143,11 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | ||||
|             WHERE client_ip = %s | ||||
|         """ | ||||
|         cursor.execute(update_query, (online_status, client_ip)) | ||||
|         conn.commit() | ||||
|  | ||||
|         # 记录状态变更历史 | ||||
|         record_status_change(client_ip, online_status) | ||||
|  | ||||
|         conn.commit() | ||||
|         return True | ||||
|     except MySQLError as e: | ||||
|         if conn: | ||||
| @ -105,46 +158,43 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 原有接口保持不变 | ||||
| # 创建设备信息接口 | ||||
| # ------------------------------ | ||||
| @router.post("/add", response_model=APIResponse, summary="创建设备信息") | ||||
| async def create_device(device_data: DeviceCreateRequest, request: Request):  # 注入Request对象 | ||||
|     # 原有代码保持不变 | ||||
| async def create_device(device_data: DeviceCreateRequest, request: Request): | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 检查设备是否已存在 | ||||
|         cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,)) | ||||
|         existing_device = cursor.fetchone() | ||||
|         if existing_device: | ||||
|             # 更新设备状态为在线 | ||||
|             # 更新设备为在线状态 | ||||
|             update_online_status_by_ip(client_ip=device_data.ip, online_status=1) | ||||
|             # 返回信息 | ||||
|             return APIResponse( | ||||
|                 code=200, | ||||
|                 message=f"设备IP {device_data.ip} 已存在、返回已有设备信息", | ||||
|                 data=DeviceResponse(**existing_device) | ||||
|             ) | ||||
|  | ||||
|         # 直接使用注入的request对象获取用户代理 | ||||
|         # 通过 User-Agent 判断设备类型 | ||||
|         user_agent = request.headers.get("User-Agent", "").lower() | ||||
|  | ||||
|         device_type = "unknown" | ||||
|         if user_agent == "default": | ||||
|             device_type = device_data.params.get("os") if ( | ||||
|                     device_data.params and isinstance(device_data.params, dict)) else "unknown" | ||||
|             device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown" | ||||
|         elif "windows" in user_agent: | ||||
|             device_type = "windows" | ||||
|         elif "android" in user_agent: | ||||
|             device_type = "android" | ||||
|         elif "linux" in user_agent: | ||||
|             device_type = "linux" | ||||
|         else: | ||||
|             device_type = "unknown" | ||||
|  | ||||
|         device_params_json = json.dumps(device_data.params) if device_data.params else None | ||||
|  | ||||
|         # 插入新设备 | ||||
|         insert_query = """ | ||||
|             INSERT INTO devices  | ||||
|             (client_ip, hostname, device_online_status, device_type, alarm_count, params) | ||||
| @ -160,10 +210,14 @@ async def create_device(device_data: DeviceCreateRequest, request: Request):  # | ||||
|         )) | ||||
|         conn.commit() | ||||
|  | ||||
|         # 获取新设备并返回 | ||||
|         device_id = cursor.lastrowid | ||||
|         cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) | ||||
|         new_device = cursor.fetchone() | ||||
|  | ||||
|         # 记录上线历史 | ||||
|         record_status_change(device_data.ip, 1) | ||||
|  | ||||
|         return APIResponse( | ||||
|             code=200, | ||||
|             message="设备创建成功", | ||||
| @ -175,7 +229,7 @@ async def create_device(device_data: DeviceCreateRequest, request: Request):  # | ||||
|             conn.rollback() | ||||
|         raise Exception(f"创建设备失败: {str(e)}") from e | ||||
|     except json.JSONDecodeError as e: | ||||
|         raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e | ||||
|         raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e | ||||
|     except Exception as e: | ||||
|         if conn: | ||||
|             conn.rollback() | ||||
| @ -183,14 +237,17 @@ async def create_device(device_data: DeviceCreateRequest, request: Request):  # | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 获取设备列表接口 | ||||
| # ------------------------------ | ||||
| @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") | ||||
| async def get_device_list( | ||||
|         page: int = Query(1, ge=1, description="页码、默认第1页"), | ||||
|         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"), | ||||
|         page: int = Query(1, ge=1, description="页码,默认第1页"), | ||||
|         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), | ||||
|         device_type: str = Query(None, description="按设备类型筛选"), | ||||
|         online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选") | ||||
| ): | ||||
|     # 原有代码保持不变 | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
| @ -207,12 +264,14 @@ async def get_device_list( | ||||
|             where_clause.append("device_online_status = %s") | ||||
|             params.append(online_status) | ||||
|  | ||||
|         # 统计总数 | ||||
|         count_query = "SELECT COUNT(*) AS total FROM devices" | ||||
|         if where_clause: | ||||
|             count_query += " WHERE " + " AND ".join(where_clause) | ||||
|         cursor.execute(count_query, params) | ||||
|         total = cursor.fetchone()["total"] | ||||
|  | ||||
|         # 分页查询列表 | ||||
|         offset = (page - 1) * page_size | ||||
|         list_query = "SELECT * FROM devices" | ||||
|         if where_clause: | ||||
| @ -238,23 +297,140 @@ async def get_device_list( | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| def get_unique_client_ips() -> list[str]: | ||||
|     """ | ||||
|     获取所有去重的客户端IP列表 | ||||
|  | ||||
|     :return: 去重后的客户端IP字符串列表,如果没有数据则返回空列表 | ||||
|     """ | ||||
| # ------------------------------ | ||||
| # 获取设备上下线记录接口 | ||||
| # ------------------------------ | ||||
| @router.get("/{device_id}/status-history", response_model=APIResponse, summary="获取设备上下线记录") | ||||
| async def get_device_status_history( | ||||
|         device_id: int = Path(..., description="设备ID"), | ||||
|         page: int = Query(1, ge=1, description="页码,默认第1页"), | ||||
|         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), | ||||
|         start_date: date = Query(None, description="开始日期,格式YYYY-MM-DD"), | ||||
|         end_date: date = Query(None, description="结束日期,格式YYYY-MM-DD") | ||||
| ): | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 检查设备是否存在并获取 client_ip | ||||
|         cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,)) | ||||
|         device = cursor.fetchone() | ||||
|         if not device: | ||||
|             raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在") | ||||
|         client_ip = device['client_ip'] | ||||
|  | ||||
|         where_clause = ["client_ip = %s"] | ||||
|         params = [client_ip] | ||||
|  | ||||
|         # 日期筛选 | ||||
|         if start_date: | ||||
|             where_clause.append("DATE(created_at) >= %s") | ||||
|             params.append(start_date.strftime("%Y-%m-%d")) | ||||
|         if end_date: | ||||
|             where_clause.append("DATE(created_at) <= %s") | ||||
|             params.append(end_date.strftime("%Y-%m-%d")) | ||||
|  | ||||
|         # 统计记录总数 | ||||
|         count_query = "SELECT COUNT(*) AS total FROM device_action WHERE " + " AND ".join(where_clause) | ||||
|         cursor.execute(count_query, params) | ||||
|         total = cursor.fetchone()["total"] | ||||
|  | ||||
|         # 分页查询记录 | ||||
|         offset = (page - 1) * page_size | ||||
|         list_query = f""" | ||||
|             SELECT * FROM device_action  | ||||
|             WHERE {' AND '.join(where_clause)} | ||||
|             ORDER BY created_at DESC  | ||||
|             LIMIT %s OFFSET %s | ||||
|         """ | ||||
|         params.extend([page_size, offset]) | ||||
|         cursor.execute(list_query, params) | ||||
|         history_list = cursor.fetchall() | ||||
|  | ||||
|         # 格式化为响应模型结构 | ||||
|         formatted_history = [] | ||||
|         for item in history_list: | ||||
|             formatted_item = { | ||||
|                 "id": item["id"], | ||||
|                 "device_id": device_id, | ||||
|                 "client_ip": item["client_ip"], | ||||
|                 "status": item["action"], | ||||
|                 "status_time": item["created_at"] | ||||
|             } | ||||
|             formatted_history.append(formatted_item) | ||||
|  | ||||
|         return APIResponse( | ||||
|             code=200, | ||||
|             message="获取设备上下线记录成功", | ||||
|             data=DeviceStatusHistoryListResponse( | ||||
|                 total=total, | ||||
|                 history=[DeviceStatusHistoryResponse(**item) for item in formatted_history] | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     except MySQLError as e: | ||||
|         raise Exception(f"获取设备上下线记录失败: {str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 手动更新设备在线状态接口 | ||||
| # ------------------------------ | ||||
| @router.put("/{device_id}/status", response_model=APIResponse, summary="更新设备在线状态") | ||||
| async def update_device_status( | ||||
|         device_id: int = Path(..., description="设备ID"), | ||||
|         status: int = Query(..., ge=0, le=1, description="在线状态(1-在线、0-离线)") | ||||
| ): | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 获取设备 client_ip | ||||
|         cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,)) | ||||
|         device = cursor.fetchone() | ||||
|         if not device: | ||||
|             raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在") | ||||
|  | ||||
|         # 更新状态 | ||||
|         success = update_online_status_by_ip(device['client_ip'], status) | ||||
|  | ||||
|         if success: | ||||
|             status_text = "在线" if status == 1 else "离线" | ||||
|             return APIResponse( | ||||
|                 code=200, | ||||
|                 message=f"设备已更新为{status_text}状态", | ||||
|                 data={"device_id": device_id, "status": status, "status_text": status_text} | ||||
|             ) | ||||
|         return APIResponse( | ||||
|             code=500, | ||||
|             message="更新设备状态失败", | ||||
|             data=None | ||||
|         ) | ||||
|  | ||||
|     except MySQLError as e: | ||||
|         raise Exception(f"更新设备状态失败: {str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # ------------------------------ | ||||
| # 获取所有去重的客户端IP列表 | ||||
| # ------------------------------ | ||||
| def get_unique_client_ips() -> list[str]: | ||||
|     """获取所有去重的客户端IP列表""" | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|  | ||||
|         # 查询去重的客户端IP | ||||
|         query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL" | ||||
|         cursor.execute(query) | ||||
|  | ||||
|         # 提取结果并转换为字符串列表 | ||||
|         results = cursor.fetchall() | ||||
|         return [item['client_ip'] for item in results] | ||||
|  | ||||
|  | ||||
| @ -1,10 +1,13 @@ | ||||
| import subprocess | ||||
| import os | ||||
| import sys | ||||
| import shutil | ||||
| import threading | ||||
| from pathlib import Path | ||||
| from datetime import datetime | ||||
| from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query | ||||
| from fastapi.responses import FileResponse | ||||
| from mysql.connector import Error as MySQLError | ||||
| import os | ||||
| import shutil | ||||
| from pathlib import Path | ||||
| from datetime import datetime | ||||
|  | ||||
| # 复用项目依赖 | ||||
| from ds.db import db | ||||
| @ -15,7 +18,7 @@ from schema.model_schema import ( | ||||
|     ModelListResponse | ||||
| ) | ||||
| from schema.response_schema import APIResponse | ||||
| from util.model_util import load_yolo_model  # 使用修复后的模型加载工具 | ||||
| from util.model_util import load_yolo_model  # 模型加载工具 | ||||
|  | ||||
| # 路径配置 | ||||
| CURRENT_FILE_PATH = Path(__file__).resolve() | ||||
| @ -28,14 +31,63 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep | ||||
| ALLOWED_MODEL_EXT = {"pt"} | ||||
| MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | ||||
|  | ||||
| # 全局模型变量 | ||||
| global _yolo_model | ||||
| # 全局模型变量(带版本标识) | ||||
| global _yolo_model, _current_model_version | ||||
| _yolo_model = None | ||||
| _current_model_version = None  # 模型版本标识(用于检测模型是否变化) | ||||
|  | ||||
| router = APIRouter(prefix="/models", tags=["模型管理"]) | ||||
|  | ||||
|  | ||||
| # 工具函数:验证模型路径 | ||||
| # 服务重启核心工具函数 | ||||
| def restart_service(): | ||||
|     """重启当前FastAPI服务进程""" | ||||
|     print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...") | ||||
|     try: | ||||
|         # 关闭所有WebSocket连接 | ||||
|         try: | ||||
|             from ws import connected_clients | ||||
|             if connected_clients: | ||||
|                 print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接") | ||||
|                 for ip, conn in list(connected_clients.items()): | ||||
|                     try: | ||||
|                         if conn.consumer_task and not conn.consumer_task.done(): | ||||
|                             conn.consumer_task.cancel() | ||||
|                         conn.websocket.close(code=1001, reason="模型更新,服务重启") | ||||
|                         connected_clients.pop(ip) | ||||
|                     except Exception as e: | ||||
|                         print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}") | ||||
|         except ImportError: | ||||
|             print("[服务重启] 未找到WebSocket连接管理模块,跳过连接关闭") | ||||
|  | ||||
|         # 关闭数据库连接 | ||||
|         if hasattr(db, "close_all_connections"): | ||||
|             db.close_all_connections() | ||||
|         else: | ||||
|             print("[警告] db模块未实现close_all_connections,可能存在连接泄漏") | ||||
|  | ||||
|         # 启动新进程 | ||||
|         python_exec = sys.executable | ||||
|         current_argv = sys.argv | ||||
|         print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}") | ||||
|         subprocess.Popen( | ||||
|             [python_exec] + current_argv, | ||||
|             close_fds=True, | ||||
|             start_new_session=True, | ||||
|             stdout=subprocess.PIPE, | ||||
|             stderr=subprocess.PIPE | ||||
|         ) | ||||
|  | ||||
|         # 退出当前进程 | ||||
|         print("[服务重启] 新进程已启动,当前进程退出") | ||||
|         sys.exit(0) | ||||
|  | ||||
|     except Exception as e: | ||||
|         print(f"[服务重启] 重启失败:{str(e)}") | ||||
|         raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e | ||||
|  | ||||
|  | ||||
| # 模型路径验证工具函数 | ||||
| def get_valid_model_abs_path(relative_path: str) -> str: | ||||
|     try: | ||||
|         relative_path = relative_path.replace("/", os.sep) | ||||
| @ -87,6 +139,49 @@ def get_valid_model_abs_path(relative_path: str) -> str: | ||||
|         ) from e | ||||
|  | ||||
|  | ||||
| # 对外提供当前模型(带版本校验) | ||||
| def get_current_yolo_model(): | ||||
|     """供检测模块获取当前最新默认模型(仅版本变化时重新加载)""" | ||||
|     global _yolo_model, _current_model_version | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|         cursor.execute("SELECT path FROM model WHERE is_default = 1") | ||||
|         default_model = cursor.fetchone() | ||||
|         if not default_model: | ||||
|             print("[get_current_yolo_model] 暂无默认模型") | ||||
|             return None | ||||
|  | ||||
|         # 1. 计算当前默认模型的唯一版本标识 | ||||
|         # (路径哈希 + 文件修改时间戳,确保模型变化时版本变化) | ||||
|         valid_abs_path = get_valid_model_abs_path(default_model["path"]) | ||||
|         model_stat = os.stat(valid_abs_path) | ||||
|         model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | ||||
|  | ||||
|         # 2. 版本未变化则复用已有模型(核心优化点) | ||||
|         if _yolo_model and _current_model_version == model_version: | ||||
|             # print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...)") | ||||
|             return _yolo_model | ||||
|  | ||||
|         # 3. 版本变化时重新加载模型 | ||||
|         _yolo_model = load_yolo_model(valid_abs_path) | ||||
|         if _yolo_model: | ||||
|             setattr(_yolo_model, "model_path", valid_abs_path) | ||||
|             _current_model_version = model_version  # 更新版本标识 | ||||
|             print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)") | ||||
|         else: | ||||
|             print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}") | ||||
|         return _yolo_model | ||||
|  | ||||
|     except Exception as e: | ||||
|         print(f"[get_current_yolo_model] 加载失败:{str(e)}") | ||||
|         return None | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # 1. 上传模型 | ||||
| @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") | ||||
| async def upload_model( | ||||
| @ -142,12 +237,16 @@ async def upload_model( | ||||
|         if not new_model: | ||||
|             raise HTTPException(status_code=500, detail="上传成功但无法获取记录") | ||||
|  | ||||
|         # 加载默认模型 | ||||
|         global _yolo_model | ||||
|         # 加载默认模型并更新版本 | ||||
|         global _yolo_model, _current_model_version | ||||
|         if is_default: | ||||
|             valid_abs_path = get_valid_model_abs_path(db_relative_path) | ||||
|             _yolo_model = load_yolo_model(valid_abs_path) | ||||
|             if not _yolo_model: | ||||
|             if _yolo_model: | ||||
|                 setattr(_yolo_model, "model_path", valid_abs_path) | ||||
|                 model_stat = os.stat(valid_abs_path) | ||||
|                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | ||||
|             else: | ||||
|                 raise HTTPException( | ||||
|                     status_code=500, | ||||
|                     detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})" | ||||
| @ -246,11 +345,15 @@ async def get_default_model(): | ||||
|             raise HTTPException(status_code=404, detail="暂无默认模型") | ||||
|  | ||||
|         valid_abs_path = get_valid_model_abs_path(default_model["path"]) | ||||
|         global _yolo_model | ||||
|         global _yolo_model, _current_model_version | ||||
|  | ||||
|         if not _yolo_model: | ||||
|             _yolo_model = load_yolo_model(valid_abs_path) | ||||
|             if not _yolo_model: | ||||
|             if _yolo_model: | ||||
|                 setattr(_yolo_model, "model_path", valid_abs_path) | ||||
|                 model_stat = os.stat(valid_abs_path) | ||||
|                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | ||||
|             else: | ||||
|                 raise HTTPException( | ||||
|                     status_code=500, | ||||
|                     detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})" | ||||
| @ -358,11 +461,16 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest): | ||||
|         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||
|         updated_model = cursor.fetchone() | ||||
|  | ||||
|         global _yolo_model | ||||
|         # 更新模型后重置版本标识 | ||||
|         global _yolo_model, _current_model_version | ||||
|         if need_load_default: | ||||
|             valid_abs_path = get_valid_model_abs_path(updated_model["path"]) | ||||
|             _yolo_model = load_yolo_model(valid_abs_path) | ||||
|             if not _yolo_model: | ||||
|             if _yolo_model: | ||||
|                 setattr(_yolo_model, "model_path", valid_abs_path) | ||||
|                 model_stat = os.stat(valid_abs_path) | ||||
|                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | ||||
|             else: | ||||
|                 raise HTTPException( | ||||
|                     status_code=500, | ||||
|                     detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})" | ||||
| @ -382,6 +490,96 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest): | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # 5.1 更换默认模型(自动重启服务) | ||||
| @router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)") | ||||
| async def set_default_model(model_id: int): | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|         conn.autocommit = False  # 开启事务 | ||||
|  | ||||
|         # 1. 校验目标模型是否存在 | ||||
|         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||
|         target_model = cursor.fetchone() | ||||
|         if not target_model: | ||||
|             raise HTTPException(status_code=404, detail=f"目标模型不存在!ID:{model_id}") | ||||
|  | ||||
|         # 2. 检查是否已为默认模型 | ||||
|         if target_model["is_default"]: | ||||
|             return APIResponse( | ||||
|                 code=200, | ||||
|                 message=f"模型ID:{model_id} 已是默认模型,无需更换和重启", | ||||
|                 data=ModelResponse(**target_model) | ||||
|             ) | ||||
|  | ||||
|         # 3. 校验目标模型文件合法性 | ||||
|         try: | ||||
|             valid_abs_path = get_valid_model_abs_path(target_model["path"]) | ||||
|         except HTTPException as e: | ||||
|             raise HTTPException( | ||||
|                 status_code=400, | ||||
|                 detail=f"目标模型文件非法,无法设为默认:{e.detail}" | ||||
|             ) from e | ||||
|  | ||||
|         # 4. 数据库事务:更新默认模型状态 | ||||
|         try: | ||||
|             cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP") | ||||
|             cursor.execute( | ||||
|                 "UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s", | ||||
|                 (model_id,) | ||||
|             ) | ||||
|             conn.commit() | ||||
|         except MySQLError as e: | ||||
|             conn.rollback() | ||||
|             raise HTTPException( | ||||
|                 status_code=500, | ||||
|                 detail=f"更新默认模型状态失败(已回滚):{str(e)}" | ||||
|             ) from e | ||||
|  | ||||
|         # 5. 验证新模型可加载性 | ||||
|         test_model = load_yolo_model(valid_abs_path) | ||||
|         if not test_model: | ||||
|             conn.rollback() | ||||
|             raise HTTPException( | ||||
|                 status_code=500, | ||||
|                 detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path})" | ||||
|             ) | ||||
|  | ||||
|         # 6. 重新查询更新后的模型信息 | ||||
|         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||
|         updated_model = cursor.fetchone() | ||||
|  | ||||
|         # 7. 重置版本标识(关键:确保下次检测加载新模型) | ||||
|         global _current_model_version | ||||
|         _current_model_version = None | ||||
|         print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型") | ||||
|  | ||||
|         # 8. 延迟重启服务 | ||||
|         print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})") | ||||
|         threading.Timer( | ||||
|             interval=1.0, | ||||
|             function=restart_service | ||||
|         ).start() | ||||
|  | ||||
|         # 9. 返回成功响应 | ||||
|         return APIResponse( | ||||
|             code=200, | ||||
|             message=f"已成功更换默认模型(ID:{model_id})!服务将在1秒后自动重启以应用新模型", | ||||
|             data=ModelResponse(**updated_model) | ||||
|         ) | ||||
|  | ||||
|     except MySQLError as e: | ||||
|         if conn: | ||||
|             conn.rollback() | ||||
|         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||
|     finally: | ||||
|         if conn: | ||||
|             conn.autocommit = True | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # 6. 删除模型 | ||||
| @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") | ||||
| async def delete_model(model_id: int): | ||||
| @ -420,10 +618,12 @@ async def delete_model(model_id: int): | ||||
|         except Exception as e: | ||||
|             extra_msg = f"(文件删除失败:{str(e)})" | ||||
|  | ||||
|         global _yolo_model | ||||
|         if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str: | ||||
|         # 如果删除的是当前加载的模型,重置缓存 | ||||
|         global _yolo_model, _current_model_version | ||||
|         if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str: | ||||
|             _yolo_model = None | ||||
|             print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})") | ||||
|             _current_model_version = None | ||||
|             print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str})") | ||||
|  | ||||
|         return APIResponse( | ||||
|             code=200, | ||||
| @ -466,32 +666,3 @@ async def download_model(model_id: int): | ||||
|         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
|  | ||||
|  | ||||
| # 对外提供当前模型 | ||||
| def get_current_yolo_model(): | ||||
|     """供检测模块获取当前加载的模型""" | ||||
|     global _yolo_model | ||||
|     if not _yolo_model: | ||||
|         conn = None | ||||
|         cursor = None | ||||
|         try: | ||||
|             conn = db.get_connection() | ||||
|             cursor = conn.cursor(dictionary=True) | ||||
|             cursor.execute("SELECT path FROM model WHERE is_default = 1") | ||||
|             default_model = cursor.fetchone() | ||||
|             if not default_model: | ||||
|                 print("[get_current_yolo_model] 暂无默认模型") | ||||
|                 return None | ||||
|  | ||||
|             valid_abs_path = get_valid_model_abs_path(default_model["path"]) | ||||
|             _yolo_model = load_yolo_model(valid_abs_path) | ||||
|             if _yolo_model: | ||||
|                 print(f"[get_current_yolo_model] 自动加载默认模型成功") | ||||
|             else: | ||||
|                 print(f"[get_current_yolo_model] 自动加载默认模型失败") | ||||
|         except Exception as e: | ||||
|             print(f"[get_current_yolo_model] 加载失败:{str(e)}") | ||||
|         finally: | ||||
|             db.close_connection(conn, cursor) | ||||
|     return _yolo_model | ||||
|  | ||||
| @ -12,7 +12,7 @@ def save_face_to_up_images( | ||||
| ) -> Dict[str, str]: | ||||
|     """ | ||||
|     保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 | ||||
|     修复路径计算错误,确保所有路径在up_images根目录下 | ||||
|     修复路径计算错误,确保所有路径在up_images根目录下,且统一使用正斜杠 | ||||
|  | ||||
|     参数: | ||||
|         client_ip: 客户端IP(原始格式,如192.168.1.101) | ||||
| @ -38,7 +38,7 @@ def save_face_to_up_images( | ||||
|         safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|'])  # 过滤非法字符 | ||||
|  | ||||
|         # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) | ||||
|         root_dir = Path("up_images").resolve()  # 转为绝对路径(关键修复!) | ||||
|         root_dir = Path("up_images").resolve()  # 转为绝对路径 | ||||
|         if not root_dir.exists(): | ||||
|             root_dir.mkdir(parents=True, exist_ok=True) | ||||
|             print(f"[FileUtil] 已创建up_images根目录:{root_dir}") | ||||
| @ -53,15 +53,16 @@ def save_face_to_up_images( | ||||
|         timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] | ||||
|         image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" | ||||
|  | ||||
|         # 6. 计算路径(关键修复:确保所有路径都是绝对路径且在root_dir下) | ||||
|         # 6. 计算路径(确保所有路径都是绝对路径且在root_dir下) | ||||
|         local_abs_path = face_name_dir / image_filename  # 绝对路径 | ||||
|  | ||||
|         # 验证路径是否在root_dir下(防止路径穿越攻击) | ||||
|         if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): | ||||
|             raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") | ||||
|  | ||||
|         # 数据库存储路径:从root_dir开始的相对路径(如 up_images/192_168_110_31/小王/xxx.jpg) | ||||
|         db_path = str(root_dir.name / local_abs_path.relative_to(root_dir)) | ||||
|         # 数据库存储路径:从root_dir开始的相对路径,强制替换为正斜杠 | ||||
|         relative_path = local_abs_path.relative_to(root_dir) | ||||
|         db_path = str(relative_path).replace("\\", "/")  # 关键修复:统一使用正斜杠 | ||||
|  | ||||
|         # 7. 写入图片文件 | ||||
|         with open(local_abs_path, "wb") as f: | ||||
| @ -72,7 +73,7 @@ def save_face_to_up_images( | ||||
|  | ||||
|         return { | ||||
|             "success": True, | ||||
|             "db_path": db_path,  # 存数据库的相对路径(up_images开头) | ||||
|             "db_path": db_path,  # 存数据库的相对路径(使用正斜杠) | ||||
|             "local_abs_path": str(local_abs_path),  # 本地绝对路径 | ||||
|             "msg": "图片保存成功" | ||||
|         } | ||||
|  | ||||
		Reference in New Issue
	
	Block a user