from fastapi import HTTPException import numpy as np import torch from MySQLdb import MySQLError from ultralytics import YOLO import os from ds.db import db from service.file_service import get_absolute_path # 全局变量:初始化时为None,无模型时保持None current_yolo_model = None current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例 ALLOWED_MODEL_EXT = {"pt"} MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB def load_yolo_model(): """ 加载模型并存储绝对路径 无有效模型路径/模型文件不存在/加载失败时,跳过加载(不抛出异常) """ global current_yolo_model, current_model_absolute_path # 1. 获取数据库中的模型路径(无模型时返回None) model_rel_path = get_enabled_model_rel_path() # 2. 无模型路径时,跳过加载 if not model_rel_path: print("[模型初始化] 未获取到有效模型路径,已跳过模型加载") current_yolo_model = None current_model_absolute_path = None return None # 3. 有模型路径时,执行正常加载流程 print(f"[模型初始化] 加载模型:{model_rel_path}") try: # 计算绝对路径(避免路径处理异常) current_model_absolute_path = get_absolute_path(model_rel_path) print(f"[模型初始化] 模型绝对路径:{current_model_absolute_path}") # 检查模型文件是否存在 if not os.path.exists(current_model_absolute_path): print(f"[模型初始化] 警告:模型文件不存在({current_model_absolute_path}),已跳过加载") current_yolo_model = None current_model_absolute_path = None return None # 加载YOLO模型 new_model = YOLO(current_model_absolute_path) # 设备分配(GPU/CPU) if torch.cuda.is_available(): new_model.to('cuda') print("[模型初始化] 模型已移动到GPU设备") else: print("[模型初始化] 未检测到GPU,使用CPU进行推理") # 更新全局模型变量 current_yolo_model = new_model print(f"[模型初始化] 成功加载模型:{current_model_absolute_path}") return current_yolo_model # 捕获所有加载异常,避免中断项目启动 except Exception as e: print(f"[模型初始化] 警告:模型加载失败({str(e)}),已跳过加载") current_yolo_model = None current_model_absolute_path = None return None def get_current_model(): """ 获取当前模型实例 无模型时返回None(不抛出异常,避免中断流程) """ return current_yolo_model def detect(image_np, conf_threshold=0.8): """ 执行YOLO检测 无模型时返回明确提示,不崩溃;有模型时正常返回检测结果 """ # 优先检查模型是否已加载 model = get_current_model() if not model: error_msg = "检测失败:未加载任何YOLO模型(数据库中无默认模型或模型加载失败)" print(f"[检测流程] {error_msg}") return False, error_msg # 返回False+错误提示,而非None # 2. 输入格式验证(保留原逻辑,格式错误仍抛异常,属于参数问题) if not isinstance(image_np, np.ndarray): raise ValueError("输入必须是numpy数组(BGR图像格式)") if image_np.ndim != 3 or image_np.shape[-1] != 3: raise ValueError(f"输入图像格式错误,需为 (高度, 宽度, 3) 的BGR数组,当前shape: {image_np.shape}") detection_results = [] try: # 3. 检测配置 device = "cuda" if torch.cuda.is_available() else "cpu" img_height, img_width = image_np.shape[:2] print(f"[检测流程] 设备:{device} | 置信度阈值:{conf_threshold} | 图像尺寸:{img_width}x{img_height}") # 4. 执行YOLO预测 print("[检测流程] 开始执行YOLO检测") results = model.predict( image_np, conf=conf_threshold, device=device, show=False, # 不显示检测窗口 verbose=False # 关闭YOLO内部日志(可选,减少冗余输出) ) # 5. 整理检测结果(仅保留置信度达标结果,原逻辑保留) for box in results[0].boxes: class_id = int(box.cls[0]) class_name = model.names[class_id] confidence = float(box.conf[0]) # 转换为整数坐标(x1, y1, x2, y2) bbox = tuple(map(int, box.xyxy[0])) # 过滤条件:置信度达标 if confidence >= conf_threshold and 0 <= class_id <= 5: detection_results.append({ "class": class_name, "confidence": round(confidence, 4), # 保留4位小数,优化输出 "bbox": bbox }) # 6. 判断是否检测到目标 has_content = len(detection_results) > 0 print(f"[检测流程] 检测完成:共检测到 {len(detection_results)} 个目标") return has_content, detection_results # 7. 捕获检测过程异常,返回明确错误信息 except Exception as e: error_msg = f"检测过程出错:{str(e)}" print(f"[检测流程] {error_msg}") return False, error_msg def get_enabled_model_rel_path(): """ 从数据库获取启用的默认模型相对路径 无模型/数据库错误时返回None,仅记录警告日志 """ conn = None cursor = None try: # 建立数据库连接 conn = db.get_connection() cursor = conn.cursor(dictionary=True) # 查询默认模型(is_default=1) query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1" cursor.execute(query) result = cursor.fetchone() # 有有效路径则返回,否则返回None if result and isinstance(result.get('path'), str) and result['path'].strip(): model_path = result['path'].strip() print(f"找到默认模型路径:{model_path}") return model_path else: print("警告:未找到启用的默认模型") return None # 捕获MySQL相关错误 except MySQLError as e: print(f"警告:查询默认模型时发生数据库错误({str(e)})") return None # 捕获其他通用错误 except Exception as e: print(f"[数据库查询] 警告:获取默认模型路径失败({str(e)})") return None # 确保数据库连接和游标关闭 finally: if cursor: try: cursor.close() print("游标已关闭") except Exception as e: print(f"关闭游标时出错:{str(e)}") # 关闭连接(允许重复关闭,无需检查是否已关闭) if conn: try: conn.close() print("数据库连接已关闭") except Exception as e: print(f"关闭数据库连接时出错:{str(e)}")