可以成功动态更换yolo模型并重启服务生效
This commit is contained in:
		
							
								
								
									
										48
									
								
								core/yolo.py
									
									
									
									
									
								
							
							
						
						
									
										48
									
								
								core/yolo.py
									
									
									
									
									
								
							| @ -1,43 +1,35 @@ | |||||||
| import os | import os | ||||||
| import numpy as np | import numpy as np | ||||||
| from ultralytics import YOLO | from ultralytics import YOLO | ||||||
| from service.model_service import get_current_yolo_model  # 从模型管理模块获取模型 | from service.model_service import get_current_yolo_model  # 带版本校验的模型获取 | ||||||
|  |  | ||||||
| # 全局模型变量 |  | ||||||
| _yolo_model = None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_model(model_path=None): | def load_model(model_path=None): | ||||||
|     """加载YOLO模型(优先使用模型管理模块的默认模型)""" |     """加载YOLO模型(优先使用带版本校验的默认模型)""" | ||||||
|     global _yolo_model |  | ||||||
|  |  | ||||||
|     if model_path is None: |     if model_path is None: | ||||||
|         _yolo_model = get_current_yolo_model() |         # 调用带版本校验的模型获取函数(自动判断是否需要重新加载) | ||||||
|         return _yolo_model is not None |         return get_current_yolo_model() | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         _yolo_model = YOLO(model_path) |         # 加载指定路径模型(用于特殊场景) | ||||||
|         return True |         return YOLO(model_path) | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"YOLO模型加载失败(指定路径):{str(e)}") |         print(f"YOLO模型加载失败(指定路径):{str(e)}") | ||||||
|         return False |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| def detect(frame, conf_threshold=0.2): | def detect(frame, conf_threshold=0.7): | ||||||
|     """执行目标检测,返回(是否成功, 结果字符串)""" |     """执行目标检测(仅模型版本变化时重新加载,平时复用缓存)""" | ||||||
|     global _yolo_model |     # 获取模型(内部已做版本校验,未变化则直接返回缓存) | ||||||
|  |     current_model = load_model() | ||||||
|     # 确保模型已加载 |     if not current_model: | ||||||
|     if not _yolo_model: |         return (False, "未加载到最新YOLO模型") | ||||||
|         if not load_model(): |  | ||||||
|             return (False, "模型未初始化") |  | ||||||
|  |  | ||||||
|     if frame is None: |     if frame is None: | ||||||
|         return (False, "无效输入帧") |         return (False, "无效输入帧") | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         # 执行检测(frame应为numpy数组) |         # 用当前模型执行检测(复用缓存,无额外加载耗时) | ||||||
|         results = _yolo_model(frame, conf=conf_threshold, verbose=False) |         results = current_model(frame, conf=conf_threshold, verbose=False) | ||||||
|         has_results = len(results[0].boxes) > 0 if results else False |         has_results = len(results[0].boxes) > 0 if results else False | ||||||
|  |  | ||||||
|         if not has_results: |         if not has_results: | ||||||
| @ -49,11 +41,17 @@ def detect(frame, conf_threshold=0.2): | |||||||
|             cls = int(box.cls[0]) |             cls = int(box.cls[0]) | ||||||
|             conf = float(box.conf[0]) |             conf = float(box.conf[0]) | ||||||
|             bbox = [round(x, 2) for x in box.xyxy[0].tolist()]  # 保留两位小数 |             bbox = [round(x, 2) for x in box.xyxy[0].tolist()]  # 保留两位小数 | ||||||
|             class_name = _yolo_model.names[cls] if hasattr(_yolo_model, 'names') else f"类别{cls}" |             # 从当前模型中获取类别名(确保与模型匹配) | ||||||
|  |             class_name = current_model.names[cls] if hasattr(current_model, 'names') else f"类别{cls}" | ||||||
|             result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})") |             result_parts.append(f"{class_name}(置信度:{conf:.2f},位置:{bbox})") | ||||||
|  |  | ||||||
|  |         # 打印当前使用的模型路径和版本(用于验证) | ||||||
|  |         # model_path = getattr(current_model, "model_path", "未知路径") | ||||||
|  |         # from service.model_service import _current_model_version | ||||||
|  |         # print(f"[YOLO检测] 使用模型:{model_path}(版本:{_current_model_version[:10]}...)") | ||||||
|  |  | ||||||
|         return (True, "; ".join(result_parts)) |         return (True, "; ".join(result_parts)) | ||||||
|  |  | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"检测过程出错:{str(e)}") |         print(f"YOLO检测过程出错:{str(e)}") | ||||||
|         return (False, f"检测错误:{str(e)}") |         return (False, f"检测错误:{str(e)}") | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								ds/db.py
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								ds/db.py
									
									
									
									
									
								
							| @ -3,6 +3,8 @@ from mysql.connector import Error | |||||||
|  |  | ||||||
| from .config import MYSQL_CONFIG | from .config import MYSQL_CONFIG | ||||||
|  |  | ||||||
|  | # 关键:声明类级别的连接池实例(必须有这一行!) | ||||||
|  | _connection_pool = None  # 确保这一行存在,且拼写正确 | ||||||
|  |  | ||||||
| class Database: | class Database: | ||||||
|     """MySQL 连接池管理类""" |     """MySQL 连接池管理类""" | ||||||
| @ -41,6 +43,18 @@ class Database: | |||||||
|         except Error as e: |         except Error as e: | ||||||
|             raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e |             raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def close_all_connections(cls): | ||||||
|  |         """清理连接池(服务重启前调用)""" | ||||||
|  |         try: | ||||||
|  |             # 先检查属性是否存在,再判断是否有值 | ||||||
|  |             if hasattr(cls, "_connection_pool") and cls._connection_pool: | ||||||
|  |                 cls._connection_pool = None  # 重置连接池 | ||||||
|  |                 print("[Database] 连接池已重置,旧连接将被自动清理") | ||||||
|  |             else: | ||||||
|  |                 print("[Database] 连接池未初始化或已重置,无需操作") | ||||||
|  |         except Exception as e: | ||||||
|  |             print(f"[Database] 重置连接池失败: {str(e)}") | ||||||
|  |  | ||||||
| # 暴露数据库操作工具 | # 暴露数据库操作工具 | ||||||
| db = Database() | db = Database() | ||||||
|  | |||||||
| @ -8,32 +8,48 @@ from pydantic import BaseModel, Field | |||||||
| # 请求模型 | # 请求模型 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| class DeviceCreateRequest(BaseModel): | class DeviceCreateRequest(BaseModel): | ||||||
|     """设备流信息创建请求模型(与数据库表字段对齐)""" |     """设备创建请求模型""" | ||||||
|     ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") |     ip: Optional[str] = Field(..., max_length=100, description="设备IP地址") | ||||||
|     hostname: Optional[str] = Field(None, max_length=100, description="设备别名") |     hostname: Optional[str] = Field(None, max_length=100, description="设备别名") | ||||||
|     params: Optional[Dict] = Field(None, description="设备详细信息(JSON格式)") |     params: Optional[Dict] = Field(None, description="设备扩展参数(JSON格式)") | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 响应模型(后端返回数据)- 严格对齐数据库表字段 | # 响应模型 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| class DeviceResponse(BaseModel): | class DeviceResponse(BaseModel): | ||||||
|     """设备流信息响应模型(与数据库表字段完全一致)""" |     """单设备信息响应模型(与数据库表字段对齐)""" | ||||||
|     id: int = Field(..., description="设备主键ID") |     id: int = Field(..., description="设备主键ID") | ||||||
|     client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址") |     client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址") | ||||||
|     hostname: Optional[str] = Field(None, max_length=100, description="设备别名") |     hostname: Optional[str] = Field(None, max_length=100, description="设备别名") | ||||||
|     device_online_status: int = Field(..., description="设备在线状态(1-在线、0-离线)") |     device_online_status: int = Field(..., description="在线状态(1-在线、0-离线)") | ||||||
|     device_type: Optional[str] = Field(None, description="设备类型") |     device_type: Optional[str] = Field(None, description="设备类型") | ||||||
|     alarm_count: int = Field(..., description="报警次数") |     alarm_count: int = Field(..., description="报警次数") | ||||||
|     params: Optional[str] = Field(None, description="设备详细信息(JSON字符串)") |     params: Optional[str] = Field(None, description="扩展参数(JSON字符串)") | ||||||
|     created_at: datetime = Field(..., description="记录创建时间") |     created_at: datetime = Field(..., description="记录创建时间") | ||||||
|     updated_at: datetime = Field(..., description="记录更新时间") |     updated_at: datetime = Field(..., description="记录更新时间") | ||||||
|  |  | ||||||
|     # 支持从数据库查询结果直接转换 |     model_config = {"from_attributes": True}  # 支持从数据库结果直接转换 | ||||||
|     model_config = {"from_attributes": True} |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceListResponse(BaseModel): | class DeviceListResponse(BaseModel): | ||||||
|     """设备流信息列表响应模型""" |     """设备列表响应模型""" | ||||||
|     total: int = Field(..., description="设备总数") |     total: int = Field(..., description="设备总数") | ||||||
|     devices: List[DeviceResponse] = Field(..., description="设备列表") |     devices: List[DeviceResponse] = Field(..., description="设备列表") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DeviceStatusHistoryResponse(BaseModel): | ||||||
|  |     """设备上下线记录响应模型""" | ||||||
|  |     id: int = Field(..., description="记录ID") | ||||||
|  |     device_id: int = Field(..., description="关联设备ID") | ||||||
|  |     client_ip: Optional[str] = Field(None, description="设备IP地址") | ||||||
|  |     status: int = Field(..., description="状态(1-在线、0-离线)") | ||||||
|  |     status_time: datetime = Field(..., description="状态变更时间") | ||||||
|  |  | ||||||
|  |     model_config = {"from_attributes": True} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DeviceStatusHistoryListResponse(BaseModel): | ||||||
|  |     """设备上下线记录列表响应模型""" | ||||||
|  |     total: int = Field(..., description="记录总数") | ||||||
|  |     history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表") | ||||||
| @ -1,10 +1,14 @@ | |||||||
| import json | import json | ||||||
|  | from datetime import datetime, date | ||||||
|  |  | ||||||
| from fastapi import APIRouter, Query, HTTPException,Request | from fastapi import APIRouter, Query, HTTPException, Request, Path | ||||||
| from mysql.connector import Error as MySQLError | from mysql.connector import Error as MySQLError | ||||||
|  |  | ||||||
| from ds.db import db | from ds.db import db | ||||||
| from schema.device_schema import DeviceCreateRequest, DeviceResponse, DeviceListResponse | from schema.device_schema import ( | ||||||
|  |     DeviceCreateRequest, DeviceResponse, DeviceListResponse, | ||||||
|  |     DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse | ||||||
|  | ) | ||||||
| from schema.response_schema import APIResponse | from schema.response_schema import APIResponse | ||||||
|  |  | ||||||
| router = APIRouter( | router = APIRouter( | ||||||
| @ -13,12 +17,53 @@ router = APIRouter( | |||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 内部工具方法 - 记录设备状态变更历史 | ||||||
|  | # ------------------------------ | ||||||
|  | def record_status_change(client_ip: str, status: int) -> bool: | ||||||
|  |     """ | ||||||
|  |     记录设备状态变更历史(写入 device_action 表) | ||||||
|  |  | ||||||
|  |     :param client_ip: 设备IP | ||||||
|  |     :param status: 状态(1-在线、0-离线) | ||||||
|  |     :return: 操作是否成功 | ||||||
|  |     """ | ||||||
|  |     if not client_ip: | ||||||
|  |         raise ValueError("客户端IP不能为空") | ||||||
|  |  | ||||||
|  |     if status not in (0, 1): | ||||||
|  |         raise ValueError("状态必须是0(离线)或1(在线)") | ||||||
|  |  | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 插入状态变更记录到 device_action | ||||||
|  |         insert_query = """ | ||||||
|  |             INSERT INTO device_action  | ||||||
|  |             (client_ip, action, created_at, updated_at) | ||||||
|  |             VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) | ||||||
|  |         """ | ||||||
|  |         cursor.execute(insert_query, (client_ip, status)) | ||||||
|  |         conn.commit() | ||||||
|  |  | ||||||
|  |         return True | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise Exception(f"记录设备状态变更失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 内部工具方法 - 通过客户端IP增加设备报警次数 | # 内部工具方法 - 通过客户端IP增加设备报警次数 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| def increment_alarm_count_by_ip(client_ip: str) -> bool: | def increment_alarm_count_by_ip(client_ip: str) -> bool: | ||||||
|     """ |     """ | ||||||
|     通过客户端IP增加设备的报警次数(内部服务方法) |     通过客户端IP增加设备的报警次数 | ||||||
|  |  | ||||||
|     :param client_ip: 客户端IP地址 |     :param client_ip: 客户端IP地址 | ||||||
|     :return: 操作是否成功 |     :return: 操作是否成功 | ||||||
| @ -34,7 +79,8 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool: | |||||||
|  |  | ||||||
|         # 检查设备是否存在 |         # 检查设备是否存在 | ||||||
|         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) |         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) | ||||||
|         if not cursor.fetchone(): |         device = cursor.fetchone() | ||||||
|  |         if not device: | ||||||
|             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") |             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") | ||||||
|  |  | ||||||
|         # 报警次数加1、并更新时间戳 |         # 报警次数加1、并更新时间戳 | ||||||
| @ -61,7 +107,7 @@ def increment_alarm_count_by_ip(client_ip: str) -> bool: | |||||||
| # ------------------------------ | # ------------------------------ | ||||||
| def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | ||||||
|     """ |     """ | ||||||
|     通过客户端IP更新设备的在线状态(内部服务方法) |     通过客户端IP更新设备的在线状态 | ||||||
|  |  | ||||||
|     :param client_ip: 客户端IP地址 |     :param client_ip: 客户端IP地址 | ||||||
|     :param online_status: 在线状态(1-在线、0-离线) |     :param online_status: 在线状态(1-在线、0-离线) | ||||||
| @ -70,7 +116,6 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | |||||||
|     if not client_ip: |     if not client_ip: | ||||||
|         raise ValueError("客户端IP不能为空") |         raise ValueError("客户端IP不能为空") | ||||||
|  |  | ||||||
|     # 验证状态值有效性 |  | ||||||
|     if online_status not in (0, 1): |     if online_status not in (0, 1): | ||||||
|         raise ValueError("在线状态必须是0(离线)或1(在线)") |         raise ValueError("在线状态必须是0(离线)或1(在线)") | ||||||
|  |  | ||||||
| @ -80,11 +125,16 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | |||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|         # 检查设备是否存在 |         # 检查设备是否存在并获取设备ID | ||||||
|         cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,)) |         cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,)) | ||||||
|         if not cursor.fetchone(): |         device = cursor.fetchone() | ||||||
|  |         if not device: | ||||||
|             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") |             raise ValueError(f"客户端IP为 {client_ip} 的设备不存在") | ||||||
|  |  | ||||||
|  |         # 状态无变化则不操作 | ||||||
|  |         if device['device_online_status'] == online_status: | ||||||
|  |             return True | ||||||
|  |  | ||||||
|         # 更新在线状态和时间戳 |         # 更新在线状态和时间戳 | ||||||
|         update_query = """ |         update_query = """ | ||||||
|             UPDATE devices  |             UPDATE devices  | ||||||
| @ -93,8 +143,11 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | |||||||
|             WHERE client_ip = %s |             WHERE client_ip = %s | ||||||
|         """ |         """ | ||||||
|         cursor.execute(update_query, (online_status, client_ip)) |         cursor.execute(update_query, (online_status, client_ip)) | ||||||
|         conn.commit() |  | ||||||
|  |  | ||||||
|  |         # 记录状态变更历史 | ||||||
|  |         record_status_change(client_ip, online_status) | ||||||
|  |  | ||||||
|  |         conn.commit() | ||||||
|         return True |         return True | ||||||
|     except MySQLError as e: |     except MySQLError as e: | ||||||
|         if conn: |         if conn: | ||||||
| @ -105,46 +158,43 @@ def update_online_status_by_ip(client_ip: str, online_status: int) -> bool: | |||||||
|  |  | ||||||
|  |  | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| # 原有接口保持不变 | # 创建设备信息接口 | ||||||
| # ------------------------------ | # ------------------------------ | ||||||
| @router.post("/add", response_model=APIResponse, summary="创建设备信息") | @router.post("/add", response_model=APIResponse, summary="创建设备信息") | ||||||
| async def create_device(device_data: DeviceCreateRequest, request: Request):  # 注入Request对象 | async def create_device(device_data: DeviceCreateRequest, request: Request): | ||||||
|     # 原有代码保持不变 |  | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
|     try: |     try: | ||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 检查设备是否已存在 | ||||||
|         cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,)) |         cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,)) | ||||||
|         existing_device = cursor.fetchone() |         existing_device = cursor.fetchone() | ||||||
|         if existing_device: |         if existing_device: | ||||||
|             # 更新设备状态为在线 |             # 更新设备为在线状态 | ||||||
|             update_online_status_by_ip(client_ip=device_data.ip, online_status=1) |             update_online_status_by_ip(client_ip=device_data.ip, online_status=1) | ||||||
|             # 返回信息 |  | ||||||
|             return APIResponse( |             return APIResponse( | ||||||
|                 code=200, |                 code=200, | ||||||
|                 message=f"设备IP {device_data.ip} 已存在、返回已有设备信息", |                 message=f"设备IP {device_data.ip} 已存在、返回已有设备信息", | ||||||
|                 data=DeviceResponse(** existing_device) |                 data=DeviceResponse(**existing_device) | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         # 直接使用注入的request对象获取用户代理 |         # 通过 User-Agent 判断设备类型 | ||||||
|         user_agent = request.headers.get("User-Agent", "").lower() |         user_agent = request.headers.get("User-Agent", "").lower() | ||||||
|  |         device_type = "unknown" | ||||||
|         if user_agent == "default": |         if user_agent == "default": | ||||||
|             device_type = device_data.params.get("os") if ( |             device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown" | ||||||
|                     device_data.params and isinstance(device_data.params, dict)) else "unknown" |  | ||||||
|         elif "windows" in user_agent: |         elif "windows" in user_agent: | ||||||
|             device_type = "windows" |             device_type = "windows" | ||||||
|         elif "android" in user_agent: |         elif "android" in user_agent: | ||||||
|             device_type = "android" |             device_type = "android" | ||||||
|         elif "linux" in user_agent: |         elif "linux" in user_agent: | ||||||
|             device_type = "linux" |             device_type = "linux" | ||||||
|         else: |  | ||||||
|             device_type = "unknown" |  | ||||||
|  |  | ||||||
|         device_params_json = json.dumps(device_data.params) if device_data.params else None |         device_params_json = json.dumps(device_data.params) if device_data.params else None | ||||||
|  |  | ||||||
|  |         # 插入新设备 | ||||||
|         insert_query = """ |         insert_query = """ | ||||||
|             INSERT INTO devices  |             INSERT INTO devices  | ||||||
|             (client_ip, hostname, device_online_status, device_type, alarm_count, params) |             (client_ip, hostname, device_online_status, device_type, alarm_count, params) | ||||||
| @ -160,10 +210,14 @@ async def create_device(device_data: DeviceCreateRequest, request: Request):  # | |||||||
|         )) |         )) | ||||||
|         conn.commit() |         conn.commit() | ||||||
|  |  | ||||||
|  |         # 获取新设备并返回 | ||||||
|         device_id = cursor.lastrowid |         device_id = cursor.lastrowid | ||||||
|         cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) |         cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,)) | ||||||
|         new_device = cursor.fetchone() |         new_device = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         # 记录上线历史 | ||||||
|  |         record_status_change(device_data.ip, 1) | ||||||
|  |  | ||||||
|         return APIResponse( |         return APIResponse( | ||||||
|             code=200, |             code=200, | ||||||
|             message="设备创建成功", |             message="设备创建成功", | ||||||
| @ -175,7 +229,7 @@ async def create_device(device_data: DeviceCreateRequest, request: Request):  # | |||||||
|             conn.rollback() |             conn.rollback() | ||||||
|         raise Exception(f"创建设备失败: {str(e)}") from e |         raise Exception(f"创建设备失败: {str(e)}") from e | ||||||
|     except json.JSONDecodeError as e: |     except json.JSONDecodeError as e: | ||||||
|         raise Exception(f"设备详细信息JSON序列化失败: {str(e)}") from e |         raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         if conn: |         if conn: | ||||||
|             conn.rollback() |             conn.rollback() | ||||||
| @ -183,14 +237,17 @@ async def create_device(device_data: DeviceCreateRequest, request: Request):  # | |||||||
|     finally: |     finally: | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 获取设备列表接口 | ||||||
|  | # ------------------------------ | ||||||
| @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") | @router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)") | ||||||
| async def get_device_list( | async def get_device_list( | ||||||
|         page: int = Query(1, ge=1, description="页码、默认第1页"), |         page: int = Query(1, ge=1, description="页码,默认第1页"), | ||||||
|         page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"), |         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), | ||||||
|         device_type: str = Query(None, description="按设备类型筛选"), |         device_type: str = Query(None, description="按设备类型筛选"), | ||||||
|         online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选") |         online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选") | ||||||
| ): | ): | ||||||
|     # 原有代码保持不变 |  | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
|     try: |     try: | ||||||
| @ -207,12 +264,14 @@ async def get_device_list( | |||||||
|             where_clause.append("device_online_status = %s") |             where_clause.append("device_online_status = %s") | ||||||
|             params.append(online_status) |             params.append(online_status) | ||||||
|  |  | ||||||
|  |         # 统计总数 | ||||||
|         count_query = "SELECT COUNT(*) AS total FROM devices" |         count_query = "SELECT COUNT(*) AS total FROM devices" | ||||||
|         if where_clause: |         if where_clause: | ||||||
|             count_query += " WHERE " + " AND ".join(where_clause) |             count_query += " WHERE " + " AND ".join(where_clause) | ||||||
|         cursor.execute(count_query, params) |         cursor.execute(count_query, params) | ||||||
|         total = cursor.fetchone()["total"] |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 分页查询列表 | ||||||
|         offset = (page - 1) * page_size |         offset = (page - 1) * page_size | ||||||
|         list_query = "SELECT * FROM devices" |         list_query = "SELECT * FROM devices" | ||||||
|         if where_clause: |         if where_clause: | ||||||
| @ -238,23 +297,140 @@ async def get_device_list( | |||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_unique_client_ips() -> list[str]: | # ------------------------------ | ||||||
|     """ | # 获取设备上下线记录接口 | ||||||
|     获取所有去重的客户端IP列表 | # ------------------------------ | ||||||
|  | @router.get("/{device_id}/status-history", response_model=APIResponse, summary="获取设备上下线记录") | ||||||
|     :return: 去重后的客户端IP字符串列表,如果没有数据则返回空列表 | async def get_device_status_history( | ||||||
|     """ |         device_id: int = Path(..., description="设备ID"), | ||||||
|  |         page: int = Query(1, ge=1, description="页码,默认第1页"), | ||||||
|  |         page_size: int = Query(10, ge=1, le=100, description="每页条数,1-100之间"), | ||||||
|  |         start_date: date = Query(None, description="开始日期,格式YYYY-MM-DD"), | ||||||
|  |         end_date: date = Query(None, description="结束日期,格式YYYY-MM-DD") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 检查设备是否存在并获取 client_ip | ||||||
|  |         cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,)) | ||||||
|  |         device = cursor.fetchone() | ||||||
|  |         if not device: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在") | ||||||
|  |         client_ip = device['client_ip'] | ||||||
|  |  | ||||||
|  |         where_clause = ["client_ip = %s"] | ||||||
|  |         params = [client_ip] | ||||||
|  |  | ||||||
|  |         # 日期筛选 | ||||||
|  |         if start_date: | ||||||
|  |             where_clause.append("DATE(created_at) >= %s") | ||||||
|  |             params.append(start_date.strftime("%Y-%m-%d")) | ||||||
|  |         if end_date: | ||||||
|  |             where_clause.append("DATE(created_at) <= %s") | ||||||
|  |             params.append(end_date.strftime("%Y-%m-%d")) | ||||||
|  |  | ||||||
|  |         # 统计记录总数 | ||||||
|  |         count_query = "SELECT COUNT(*) AS total FROM device_action WHERE " + " AND ".join(where_clause) | ||||||
|  |         cursor.execute(count_query, params) | ||||||
|  |         total = cursor.fetchone()["total"] | ||||||
|  |  | ||||||
|  |         # 分页查询记录 | ||||||
|  |         offset = (page - 1) * page_size | ||||||
|  |         list_query = f""" | ||||||
|  |             SELECT * FROM device_action  | ||||||
|  |             WHERE {' AND '.join(where_clause)} | ||||||
|  |             ORDER BY created_at DESC  | ||||||
|  |             LIMIT %s OFFSET %s | ||||||
|  |         """ | ||||||
|  |         params.extend([page_size, offset]) | ||||||
|  |         cursor.execute(list_query, params) | ||||||
|  |         history_list = cursor.fetchall() | ||||||
|  |  | ||||||
|  |         # 格式化为响应模型结构 | ||||||
|  |         formatted_history = [] | ||||||
|  |         for item in history_list: | ||||||
|  |             formatted_item = { | ||||||
|  |                 "id": item["id"], | ||||||
|  |                 "device_id": device_id, | ||||||
|  |                 "client_ip": item["client_ip"], | ||||||
|  |                 "status": item["action"], | ||||||
|  |                 "status_time": item["created_at"] | ||||||
|  |             } | ||||||
|  |             formatted_history.append(formatted_item) | ||||||
|  |  | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message="获取设备上下线记录成功", | ||||||
|  |             data=DeviceStatusHistoryListResponse( | ||||||
|  |                 total=total, | ||||||
|  |                 history=[DeviceStatusHistoryResponse(**item) for item in formatted_history] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"获取设备上下线记录失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 手动更新设备在线状态接口 | ||||||
|  | # ------------------------------ | ||||||
|  | @router.put("/{device_id}/status", response_model=APIResponse, summary="更新设备在线状态") | ||||||
|  | async def update_device_status( | ||||||
|  |         device_id: int = Path(..., description="设备ID"), | ||||||
|  |         status: int = Query(..., ge=0, le=1, description="在线状态(1-在线、0-离线)") | ||||||
|  | ): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|  |         # 获取设备 client_ip | ||||||
|  |         cursor.execute("SELECT id, client_ip FROM devices WHERE id = %s", (device_id,)) | ||||||
|  |         device = cursor.fetchone() | ||||||
|  |         if not device: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"设备ID为 {device_id} 的设备不存在") | ||||||
|  |  | ||||||
|  |         # 更新状态 | ||||||
|  |         success = update_online_status_by_ip(device['client_ip'], status) | ||||||
|  |  | ||||||
|  |         if success: | ||||||
|  |             status_text = "在线" if status == 1 else "离线" | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message=f"设备已更新为{status_text}状态", | ||||||
|  |                 data={"device_id": device_id, "status": status, "status_text": status_text} | ||||||
|  |             ) | ||||||
|  |         return APIResponse( | ||||||
|  |             code=500, | ||||||
|  |             message="更新设备状态失败", | ||||||
|  |             data=None | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         raise Exception(f"更新设备状态失败: {str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # ------------------------------ | ||||||
|  | # 获取所有去重的客户端IP列表 | ||||||
|  | # ------------------------------ | ||||||
|  | def get_unique_client_ips() -> list[str]: | ||||||
|  |     """获取所有去重的客户端IP列表""" | ||||||
|     conn = None |     conn = None | ||||||
|     cursor = None |     cursor = None | ||||||
|     try: |     try: | ||||||
|         conn = db.get_connection() |         conn = db.get_connection() | ||||||
|         cursor = conn.cursor(dictionary=True) |         cursor = conn.cursor(dictionary=True) | ||||||
|  |  | ||||||
|         # 查询去重的客户端IP |  | ||||||
|         query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL" |         query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL" | ||||||
|         cursor.execute(query) |         cursor.execute(query) | ||||||
|  |  | ||||||
|         # 提取结果并转换为字符串列表 |  | ||||||
|         results = cursor.fetchall() |         results = cursor.fetchall() | ||||||
|         return [item['client_ip'] for item in results] |         return [item['client_ip'] for item in results] | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,10 +1,13 @@ | |||||||
|  | import subprocess | ||||||
|  | import os | ||||||
|  | import sys | ||||||
|  | import shutil | ||||||
|  | import threading | ||||||
|  | from pathlib import Path | ||||||
|  | from datetime import datetime | ||||||
| from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query | from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query | ||||||
| from fastapi.responses import FileResponse | from fastapi.responses import FileResponse | ||||||
| from mysql.connector import Error as MySQLError | from mysql.connector import Error as MySQLError | ||||||
| import os |  | ||||||
| import shutil |  | ||||||
| from pathlib import Path |  | ||||||
| from datetime import datetime |  | ||||||
|  |  | ||||||
| # 复用项目依赖 | # 复用项目依赖 | ||||||
| from ds.db import db | from ds.db import db | ||||||
| @ -15,7 +18,7 @@ from schema.model_schema import ( | |||||||
|     ModelListResponse |     ModelListResponse | ||||||
| ) | ) | ||||||
| from schema.response_schema import APIResponse | from schema.response_schema import APIResponse | ||||||
| from util.model_util import load_yolo_model  # 使用修复后的模型加载工具 | from util.model_util import load_yolo_model  # 模型加载工具 | ||||||
|  |  | ||||||
| # 路径配置 | # 路径配置 | ||||||
| CURRENT_FILE_PATH = Path(__file__).resolve() | CURRENT_FILE_PATH = Path(__file__).resolve() | ||||||
| @ -28,14 +31,63 @@ DB_PATH_PREFIX_TO_REMOVE = str(PROJECT_ROOT) + os.sep | |||||||
| ALLOWED_MODEL_EXT = {"pt"} | ALLOWED_MODEL_EXT = {"pt"} | ||||||
| MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | ||||||
|  |  | ||||||
| # 全局模型变量 | # 全局模型变量(带版本标识) | ||||||
| global _yolo_model | global _yolo_model, _current_model_version | ||||||
| _yolo_model = None | _yolo_model = None | ||||||
|  | _current_model_version = None  # 模型版本标识(用于检测模型是否变化) | ||||||
|  |  | ||||||
| router = APIRouter(prefix="/models", tags=["模型管理"]) | router = APIRouter(prefix="/models", tags=["模型管理"]) | ||||||
|  |  | ||||||
|  |  | ||||||
| # 工具函数:验证模型路径 | # 服务重启核心工具函数 | ||||||
|  | def restart_service(): | ||||||
|  |     """重启当前FastAPI服务进程""" | ||||||
|  |     print("\n[服务重启] 检测到默认模型更换,开始清理资源并重启...") | ||||||
|  |     try: | ||||||
|  |         # 关闭所有WebSocket连接 | ||||||
|  |         try: | ||||||
|  |             from ws import connected_clients | ||||||
|  |             if connected_clients: | ||||||
|  |                 print(f"[服务重启] 关闭{len(connected_clients)}个WebSocket旧连接") | ||||||
|  |                 for ip, conn in list(connected_clients.items()): | ||||||
|  |                     try: | ||||||
|  |                         if conn.consumer_task and not conn.consumer_task.done(): | ||||||
|  |                             conn.consumer_task.cancel() | ||||||
|  |                         conn.websocket.close(code=1001, reason="模型更新,服务重启") | ||||||
|  |                         connected_clients.pop(ip) | ||||||
|  |                     except Exception as e: | ||||||
|  |                         print(f"[服务重启] 关闭客户端{ip}连接失败:{str(e)}") | ||||||
|  |         except ImportError: | ||||||
|  |             print("[服务重启] 未找到WebSocket连接管理模块,跳过连接关闭") | ||||||
|  |  | ||||||
|  |         # 关闭数据库连接 | ||||||
|  |         if hasattr(db, "close_all_connections"): | ||||||
|  |             db.close_all_connections() | ||||||
|  |         else: | ||||||
|  |             print("[警告] db模块未实现close_all_connections,可能存在连接泄漏") | ||||||
|  |  | ||||||
|  |         # 启动新进程 | ||||||
|  |         python_exec = sys.executable | ||||||
|  |         current_argv = sys.argv | ||||||
|  |         print(f"[服务重启] 启动新进程:{python_exec} {' '.join(current_argv)}") | ||||||
|  |         subprocess.Popen( | ||||||
|  |             [python_exec] + current_argv, | ||||||
|  |             close_fds=True, | ||||||
|  |             start_new_session=True, | ||||||
|  |             stdout=subprocess.PIPE, | ||||||
|  |             stderr=subprocess.PIPE | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # 退出当前进程 | ||||||
|  |         print("[服务重启] 新进程已启动,当前进程退出") | ||||||
|  |         sys.exit(0) | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"[服务重启] 重启失败:{str(e)}") | ||||||
|  |         raise HTTPException(status_code=500, detail=f"模型更换成功,但服务重启失败:{str(e)}") from e | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 模型路径验证工具函数 | ||||||
| def get_valid_model_abs_path(relative_path: str) -> str: | def get_valid_model_abs_path(relative_path: str) -> str: | ||||||
|     try: |     try: | ||||||
|         relative_path = relative_path.replace("/", os.sep) |         relative_path = relative_path.replace("/", os.sep) | ||||||
| @ -87,6 +139,49 @@ def get_valid_model_abs_path(relative_path: str) -> str: | |||||||
|         ) from e |         ) from e | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 对外提供当前模型(带版本校验) | ||||||
|  | def get_current_yolo_model(): | ||||||
|  |     """供检测模块获取当前最新默认模型(仅版本变化时重新加载)""" | ||||||
|  |     global _yolo_model, _current_model_version | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         cursor.execute("SELECT path FROM model WHERE is_default = 1") | ||||||
|  |         default_model = cursor.fetchone() | ||||||
|  |         if not default_model: | ||||||
|  |             print("[get_current_yolo_model] 暂无默认模型") | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         # 1. 计算当前默认模型的唯一版本标识 | ||||||
|  |         # (路径哈希 + 文件修改时间戳,确保模型变化时版本变化) | ||||||
|  |         valid_abs_path = get_valid_model_abs_path(default_model["path"]) | ||||||
|  |         model_stat = os.stat(valid_abs_path) | ||||||
|  |         model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | ||||||
|  |  | ||||||
|  |         # 2. 版本未变化则复用已有模型(核心优化点) | ||||||
|  |         if _yolo_model and _current_model_version == model_version: | ||||||
|  |             # print(f"[get_current_yolo_model] 模型版本未变,复用缓存(版本:{_current_model_version[:10]}...)") | ||||||
|  |             return _yolo_model | ||||||
|  |  | ||||||
|  |         # 3. 版本变化时重新加载模型 | ||||||
|  |         _yolo_model = load_yolo_model(valid_abs_path) | ||||||
|  |         if _yolo_model: | ||||||
|  |             setattr(_yolo_model, "model_path", valid_abs_path) | ||||||
|  |             _current_model_version = model_version  # 更新版本标识 | ||||||
|  |             print(f"[get_current_yolo_model] 模型版本更新,重新加载(新版本:{model_version[:10]}...)") | ||||||
|  |         else: | ||||||
|  |             print(f"[get_current_yolo_model] 加载最新默认模型失败:{valid_abs_path}") | ||||||
|  |         return _yolo_model | ||||||
|  |  | ||||||
|  |     except Exception as e: | ||||||
|  |         print(f"[get_current_yolo_model] 加载失败:{str(e)}") | ||||||
|  |         return None | ||||||
|  |     finally: | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| # 1. 上传模型 | # 1. 上传模型 | ||||||
| @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") | @router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)") | ||||||
| async def upload_model( | async def upload_model( | ||||||
| @ -142,12 +237,16 @@ async def upload_model( | |||||||
|         if not new_model: |         if not new_model: | ||||||
|             raise HTTPException(status_code=500, detail="上传成功但无法获取记录") |             raise HTTPException(status_code=500, detail="上传成功但无法获取记录") | ||||||
|  |  | ||||||
|         # 加载默认模型 |         # 加载默认模型并更新版本 | ||||||
|         global _yolo_model |         global _yolo_model, _current_model_version | ||||||
|         if is_default: |         if is_default: | ||||||
|             valid_abs_path = get_valid_model_abs_path(db_relative_path) |             valid_abs_path = get_valid_model_abs_path(db_relative_path) | ||||||
|             _yolo_model = load_yolo_model(valid_abs_path) |             _yolo_model = load_yolo_model(valid_abs_path) | ||||||
|             if not _yolo_model: |             if _yolo_model: | ||||||
|  |                 setattr(_yolo_model, "model_path", valid_abs_path) | ||||||
|  |                 model_stat = os.stat(valid_abs_path) | ||||||
|  |                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | ||||||
|  |             else: | ||||||
|                 raise HTTPException( |                 raise HTTPException( | ||||||
|                     status_code=500, |                     status_code=500, | ||||||
|                     detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})" |                     detail=f"上传成功,但加载默认模型失败(路径:{valid_abs_path})" | ||||||
| @ -246,11 +345,15 @@ async def get_default_model(): | |||||||
|             raise HTTPException(status_code=404, detail="暂无默认模型") |             raise HTTPException(status_code=404, detail="暂无默认模型") | ||||||
|  |  | ||||||
|         valid_abs_path = get_valid_model_abs_path(default_model["path"]) |         valid_abs_path = get_valid_model_abs_path(default_model["path"]) | ||||||
|         global _yolo_model |         global _yolo_model, _current_model_version | ||||||
|  |  | ||||||
|         if not _yolo_model: |         if not _yolo_model: | ||||||
|             _yolo_model = load_yolo_model(valid_abs_path) |             _yolo_model = load_yolo_model(valid_abs_path) | ||||||
|             if not _yolo_model: |             if _yolo_model: | ||||||
|  |                 setattr(_yolo_model, "model_path", valid_abs_path) | ||||||
|  |                 model_stat = os.stat(valid_abs_path) | ||||||
|  |                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | ||||||
|  |             else: | ||||||
|                 raise HTTPException( |                 raise HTTPException( | ||||||
|                     status_code=500, |                     status_code=500, | ||||||
|                     detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})" |                     detail=f"默认模型存在,但加载失败(路径:{valid_abs_path})" | ||||||
| @ -358,11 +461,16 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest): | |||||||
|         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|         updated_model = cursor.fetchone() |         updated_model = cursor.fetchone() | ||||||
|  |  | ||||||
|         global _yolo_model |         # 更新模型后重置版本标识 | ||||||
|  |         global _yolo_model, _current_model_version | ||||||
|         if need_load_default: |         if need_load_default: | ||||||
|             valid_abs_path = get_valid_model_abs_path(updated_model["path"]) |             valid_abs_path = get_valid_model_abs_path(updated_model["path"]) | ||||||
|             _yolo_model = load_yolo_model(valid_abs_path) |             _yolo_model = load_yolo_model(valid_abs_path) | ||||||
|             if not _yolo_model: |             if _yolo_model: | ||||||
|  |                 setattr(_yolo_model, "model_path", valid_abs_path) | ||||||
|  |                 model_stat = os.stat(valid_abs_path) | ||||||
|  |                 _current_model_version = f"{hash(valid_abs_path)}_{model_stat.st_mtime_ns}" | ||||||
|  |             else: | ||||||
|                 raise HTTPException( |                 raise HTTPException( | ||||||
|                     status_code=500, |                     status_code=500, | ||||||
|                     detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})" |                     detail=f"更新成功,但加载新默认模型失败(路径:{valid_abs_path})" | ||||||
| @ -382,6 +490,96 @@ async def update_model(model_id: int, model_update: ModelUpdateRequest): | |||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # 5.1 更换默认模型(自动重启服务) | ||||||
|  | @router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型(自动重启服务)") | ||||||
|  | async def set_default_model(model_id: int): | ||||||
|  |     conn = None | ||||||
|  |     cursor = None | ||||||
|  |     try: | ||||||
|  |         conn = db.get_connection() | ||||||
|  |         cursor = conn.cursor(dictionary=True) | ||||||
|  |         conn.autocommit = False  # 开启事务 | ||||||
|  |  | ||||||
|  |         # 1. 校验目标模型是否存在 | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         target_model = cursor.fetchone() | ||||||
|  |         if not target_model: | ||||||
|  |             raise HTTPException(status_code=404, detail=f"目标模型不存在!ID:{model_id}") | ||||||
|  |  | ||||||
|  |         # 2. 检查是否已为默认模型 | ||||||
|  |         if target_model["is_default"]: | ||||||
|  |             return APIResponse( | ||||||
|  |                 code=200, | ||||||
|  |                 message=f"模型ID:{model_id} 已是默认模型,无需更换和重启", | ||||||
|  |                 data=ModelResponse(**target_model) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 3. 校验目标模型文件合法性 | ||||||
|  |         try: | ||||||
|  |             valid_abs_path = get_valid_model_abs_path(target_model["path"]) | ||||||
|  |         except HTTPException as e: | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=400, | ||||||
|  |                 detail=f"目标模型文件非法,无法设为默认:{e.detail}" | ||||||
|  |             ) from e | ||||||
|  |  | ||||||
|  |         # 4. 数据库事务:更新默认模型状态 | ||||||
|  |         try: | ||||||
|  |             cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP") | ||||||
|  |             cursor.execute( | ||||||
|  |                 "UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s", | ||||||
|  |                 (model_id,) | ||||||
|  |             ) | ||||||
|  |             conn.commit() | ||||||
|  |         except MySQLError as e: | ||||||
|  |             conn.rollback() | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=500, | ||||||
|  |                 detail=f"更新默认模型状态失败(已回滚):{str(e)}" | ||||||
|  |             ) from e | ||||||
|  |  | ||||||
|  |         # 5. 验证新模型可加载性 | ||||||
|  |         test_model = load_yolo_model(valid_abs_path) | ||||||
|  |         if not test_model: | ||||||
|  |             conn.rollback() | ||||||
|  |             raise HTTPException( | ||||||
|  |                 status_code=500, | ||||||
|  |                 detail=f"新默认模型加载失败,已回滚状态(路径:{valid_abs_path})" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         # 6. 重新查询更新后的模型信息 | ||||||
|  |         cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,)) | ||||||
|  |         updated_model = cursor.fetchone() | ||||||
|  |  | ||||||
|  |         # 7. 重置版本标识(关键:确保下次检测加载新模型) | ||||||
|  |         global _current_model_version | ||||||
|  |         _current_model_version = None | ||||||
|  |         print(f"[更换默认模型] 已重置模型版本标识,下次检测将加载新模型") | ||||||
|  |  | ||||||
|  |         # 8. 延迟重启服务 | ||||||
|  |         print(f"[更换默认模型] 成功!将在1秒后重启服务以应用新模型(ID:{model_id})") | ||||||
|  |         threading.Timer( | ||||||
|  |             interval=1.0, | ||||||
|  |             function=restart_service | ||||||
|  |         ).start() | ||||||
|  |  | ||||||
|  |         # 9. 返回成功响应 | ||||||
|  |         return APIResponse( | ||||||
|  |             code=200, | ||||||
|  |             message=f"已成功更换默认模型(ID:{model_id})!服务将在1秒后自动重启以应用新模型", | ||||||
|  |             data=ModelResponse(**updated_model) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     except MySQLError as e: | ||||||
|  |         if conn: | ||||||
|  |             conn.rollback() | ||||||
|  |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|  |     finally: | ||||||
|  |         if conn: | ||||||
|  |             conn.autocommit = True | ||||||
|  |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| # 6. 删除模型 | # 6. 删除模型 | ||||||
| @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") | @router.delete("/{model_id}", response_model=APIResponse, summary="删除模型") | ||||||
| async def delete_model(model_id: int): | async def delete_model(model_id: int): | ||||||
| @ -420,10 +618,12 @@ async def delete_model(model_id: int): | |||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             extra_msg = f"(文件删除失败:{str(e)})" |             extra_msg = f"(文件删除失败:{str(e)})" | ||||||
|  |  | ||||||
|         global _yolo_model |         # 如果删除的是当前加载的模型,重置缓存 | ||||||
|         if _yolo_model and str(_yolo_model.model_path) == model_abs_path_str: |         global _yolo_model, _current_model_version | ||||||
|  |         if _yolo_model and str(getattr(_yolo_model, "model_path", "")) == model_abs_path_str: | ||||||
|             _yolo_model = None |             _yolo_model = None | ||||||
|             print(f"[模型删除] 已清空全局模型(路径:{model_abs_path_str})") |             _current_model_version = None | ||||||
|  |             print(f"[模型删除] 已清空全局模型缓存(路径:{model_abs_path_str})") | ||||||
|  |  | ||||||
|         return APIResponse( |         return APIResponse( | ||||||
|             code=200, |             code=200, | ||||||
| @ -466,32 +666,3 @@ async def download_model(model_id: int): | |||||||
|         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e |         raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e | ||||||
|     finally: |     finally: | ||||||
|         db.close_connection(conn, cursor) |         db.close_connection(conn, cursor) | ||||||
|  |  | ||||||
|  |  | ||||||
| # 对外提供当前模型 |  | ||||||
| def get_current_yolo_model(): |  | ||||||
|     """供检测模块获取当前加载的模型""" |  | ||||||
|     global _yolo_model |  | ||||||
|     if not _yolo_model: |  | ||||||
|         conn = None |  | ||||||
|         cursor = None |  | ||||||
|         try: |  | ||||||
|             conn = db.get_connection() |  | ||||||
|             cursor = conn.cursor(dictionary=True) |  | ||||||
|             cursor.execute("SELECT path FROM model WHERE is_default = 1") |  | ||||||
|             default_model = cursor.fetchone() |  | ||||||
|             if not default_model: |  | ||||||
|                 print("[get_current_yolo_model] 暂无默认模型") |  | ||||||
|                 return None |  | ||||||
|  |  | ||||||
|             valid_abs_path = get_valid_model_abs_path(default_model["path"]) |  | ||||||
|             _yolo_model = load_yolo_model(valid_abs_path) |  | ||||||
|             if _yolo_model: |  | ||||||
|                 print(f"[get_current_yolo_model] 自动加载默认模型成功") |  | ||||||
|             else: |  | ||||||
|                 print(f"[get_current_yolo_model] 自动加载默认模型失败") |  | ||||||
|         except Exception as e: |  | ||||||
|             print(f"[get_current_yolo_model] 加载失败:{str(e)}") |  | ||||||
|         finally: |  | ||||||
|             db.close_connection(conn, cursor) |  | ||||||
|     return _yolo_model |  | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ def save_face_to_up_images( | |||||||
| ) -> Dict[str, str]: | ) -> Dict[str, str]: | ||||||
|     """ |     """ | ||||||
|     保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 |     保存人脸图片到 `/up_images/用户IP/人脸名字/` 路径 | ||||||
|     修复路径计算错误,确保所有路径在up_images根目录下 |     修复路径计算错误,确保所有路径在up_images根目录下,且统一使用正斜杠 | ||||||
|  |  | ||||||
|     参数: |     参数: | ||||||
|         client_ip: 客户端IP(原始格式,如192.168.1.101) |         client_ip: 客户端IP(原始格式,如192.168.1.101) | ||||||
| @ -38,7 +38,7 @@ def save_face_to_up_images( | |||||||
|         safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|'])  # 过滤非法字符 |         safe_face_name = "".join([c for c in safe_face_name if c not in r'\/:*?"<>|'])  # 过滤非法字符 | ||||||
|  |  | ||||||
|         # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) |         # 3. 构建根目录(强制转为绝对路径,避免相对路径混淆) | ||||||
|         root_dir = Path("up_images").resolve()  # 转为绝对路径(关键修复!) |         root_dir = Path("up_images").resolve()  # 转为绝对路径 | ||||||
|         if not root_dir.exists(): |         if not root_dir.exists(): | ||||||
|             root_dir.mkdir(parents=True, exist_ok=True) |             root_dir.mkdir(parents=True, exist_ok=True) | ||||||
|             print(f"[FileUtil] 已创建up_images根目录:{root_dir}") |             print(f"[FileUtil] 已创建up_images根目录:{root_dir}") | ||||||
| @ -53,15 +53,16 @@ def save_face_to_up_images( | |||||||
|         timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] |         timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3] | ||||||
|         image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" |         image_filename = f"face_{safe_ip}_{timestamp}.{image_format.lower()}" | ||||||
|  |  | ||||||
|         # 6. 计算路径(关键修复:确保所有路径都是绝对路径且在root_dir下) |         # 6. 计算路径(确保所有路径都是绝对路径且在root_dir下) | ||||||
|         local_abs_path = face_name_dir / image_filename  # 绝对路径 |         local_abs_path = face_name_dir / image_filename  # 绝对路径 | ||||||
|  |  | ||||||
|         # 验证路径是否在root_dir下(防止路径穿越攻击) |         # 验证路径是否在root_dir下(防止路径穿越攻击) | ||||||
|         if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): |         if not local_abs_path.resolve().is_relative_to(root_dir.resolve()): | ||||||
|             raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") |             raise Exception(f"图片路径不在up_images根目录下(安全校验失败):{local_abs_path}") | ||||||
|  |  | ||||||
|         # 数据库存储路径:从root_dir开始的相对路径(如 up_images/192_168_110_31/小王/xxx.jpg) |         # 数据库存储路径:从root_dir开始的相对路径,强制替换为正斜杠 | ||||||
|         db_path = str(root_dir.name / local_abs_path.relative_to(root_dir)) |         relative_path = local_abs_path.relative_to(root_dir) | ||||||
|  |         db_path = str(relative_path).replace("\\", "/")  # 关键修复:统一使用正斜杠 | ||||||
|  |  | ||||||
|         # 7. 写入图片文件 |         # 7. 写入图片文件 | ||||||
|         with open(local_abs_path, "wb") as f: |         with open(local_abs_path, "wb") as f: | ||||||
| @ -72,7 +73,7 @@ def save_face_to_up_images( | |||||||
|  |  | ||||||
|         return { |         return { | ||||||
|             "success": True, |             "success": True, | ||||||
|             "db_path": db_path,  # 存数据库的相对路径(up_images开头) |             "db_path": db_path,  # 存数据库的相对路径(使用正斜杠) | ||||||
|             "local_abs_path": str(local_abs_path),  # 本地绝对路径 |             "local_abs_path": str(local_abs_path),  # 本地绝对路径 | ||||||
|             "msg": "图片保存成功" |             "msg": "图片保存成功" | ||||||
|         } |         } | ||||||
| @ -80,4 +81,4 @@ def save_face_to_up_images( | |||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         error_msg = f"图片保存失败:{str(e)}" |         error_msg = f"图片保存失败:{str(e)}" | ||||||
|         print(f"[FileUtil] 错误:{error_msg}") |         print(f"[FileUtil] 错误:{error_msg}") | ||||||
|         return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg} |         return {"success": False, "db_path": "", "local_abs_path": "", "msg": error_msg} | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user