from http.client import HTTPException import numpy as np import torch from MySQLdb import MySQLError from ultralytics import YOLO import os from ds.db import db from service.file_service import get_absolute_path # 全局变量 current_yolo_model = None current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例 ALLOWED_MODEL_EXT = {"pt"} MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB def load_yolo_model(): """加载模型并存储绝对路径""" global current_yolo_model, current_model_absolute_path model_rel_path = get_enabled_model_rel_path() print(f"[模型初始化] 加载模型:{model_rel_path}") # 计算并存储绝对路径 current_model_absolute_path = get_absolute_path(model_rel_path) print(f"[模型初始化] 绝对路径:{current_model_absolute_path}") # 检查模型文件 if not os.path.exists(current_model_absolute_path): raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}") try: new_model = YOLO(current_model_absolute_path) if torch.cuda.is_available(): new_model.to('cuda') print("模型已移动到GPU") else: print("使用CPU进行推理") current_yolo_model = new_model print(f"成功加载模型: {current_model_absolute_path}") return current_yolo_model except Exception as e: print(f"模型加载失败:{str(e)}") raise def get_current_model(): """获取当前模型实例""" if current_yolo_model is None: raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型") return current_yolo_model def detect(image_np, conf_threshold=0.8): # 1. 输入格式验证 if not isinstance(image_np, np.ndarray): raise ValueError("输入必须是numpy数组(BGR图像)") if image_np.ndim != 3 or image_np.shape[-1] != 3: raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}") detection_results = [] try: model = get_current_model() if not current_model_absolute_path: raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型") device = "cuda" if torch.cuda.is_available() else "cpu" print(f"检测设备:{device} | 置信度阈值:{conf_threshold}") # 图像尺寸信息 img_height, img_width = image_np.shape[:2] print(f"输入图像尺寸:{img_width}x{img_height}") # YOLO检测 print("执行YOLO检测") results = model.predict( image_np, conf=conf_threshold, device=device, show=False, ) # 4. 整理检测结果(仅保留Chest类别,ID=2) for box in results[0].boxes: class_id = int(box.cls[0]) # 类别ID class_name = model.names[class_id] confidence = float(box.conf[0]) bbox = tuple(map(int, box.xyxy[0])) # 过滤条件:置信度达标 + 类别为Chest(class_id=2) # and class_id == 2 if confidence >= conf_threshold: detection_results.append({ "class": class_name, "confidence": confidence, "bbox": bbox }) # 判断是否有目标 has_content = len(detection_results) > 0 return has_content, detection_results except Exception as e: error_msg = f"检测过程出错:{str(e)}" print(error_msg) return False, None def get_enabled_model_rel_path(): """获取数据库中启用的模型相对路径""" conn = None cursor = None try: conn = db.get_connection() cursor = conn.cursor(dictionary=True) query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1" cursor.execute(query) result = cursor.fetchone() if not result or not result.get('path'): raise HTTPException(status_code=404, detail="未找到启用的默认模型") return result['path'] except MySQLError as e: raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e except Exception as e: if isinstance(e, HTTPException): raise e raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e finally: db.close_connection(conn, cursor)