131 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			131 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | from http.client import HTTPException | |||
|  | 
 | |||
|  | import numpy as np | |||
|  | import torch | |||
|  | from MySQLdb import MySQLError | |||
|  | from ultralytics import YOLO | |||
|  | import os | |||
|  | 
 | |||
|  | from ds.db import db | |||
|  | from service.file_service import get_absolute_path | |||
|  | 
 | |||
|  | # 全局变量 | |||
|  | current_yolo_model = None | |||
|  | current_model_absolute_path = None  # 存储模型绝对路径,不依赖model实例 | |||
|  | 
 | |||
|  | ALLOWED_MODEL_EXT = {"pt"} | |||
|  | MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | |||
|  | 
 | |||
|  | 
 | |||
|  | def load_yolo_model(): | |||
|  |     """加载模型并存储绝对路径""" | |||
|  |     global current_yolo_model, current_model_absolute_path | |||
|  |     model_rel_path = get_enabled_model_rel_path() | |||
|  |     print(f"[模型初始化] 加载模型:{model_rel_path}") | |||
|  | 
 | |||
|  |     # 计算并存储绝对路径 | |||
|  |     current_model_absolute_path = get_absolute_path(model_rel_path) | |||
|  |     print(f"[模型初始化] 绝对路径:{current_model_absolute_path}") | |||
|  | 
 | |||
|  |     # 检查模型文件 | |||
|  |     if not os.path.exists(current_model_absolute_path): | |||
|  |         raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}") | |||
|  | 
 | |||
|  |     try: | |||
|  |         new_model = YOLO(current_model_absolute_path) | |||
|  |         if torch.cuda.is_available(): | |||
|  |             new_model.to('cuda') | |||
|  |             print("模型已移动到GPU") | |||
|  |         else: | |||
|  |             print("使用CPU进行推理") | |||
|  |         current_yolo_model = new_model | |||
|  |         print(f"成功加载模型: {current_model_absolute_path}") | |||
|  |         return current_yolo_model | |||
|  |     except Exception as e: | |||
|  |         print(f"模型加载失败:{str(e)}") | |||
|  |         raise | |||
|  | 
 | |||
|  | 
 | |||
|  | def get_current_model(): | |||
|  |     """获取当前模型实例""" | |||
|  |     if current_yolo_model is None: | |||
|  |         raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型") | |||
|  |     return current_yolo_model | |||
|  | 
 | |||
|  | 
 | |||
|  | def detect(image_np, conf_threshold=0.8): | |||
|  |     # 1. 输入格式验证 | |||
|  |     if not isinstance(image_np, np.ndarray): | |||
|  |         raise ValueError("输入必须是numpy数组(BGR图像)") | |||
|  |     if image_np.ndim != 3 or image_np.shape[-1] != 3: | |||
|  |         raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}") | |||
|  |     detection_results = [] | |||
|  |     try: | |||
|  |         model = get_current_model() | |||
|  |         if not current_model_absolute_path: | |||
|  |             raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型") | |||
|  |         device = "cuda" if torch.cuda.is_available() else "cpu" | |||
|  |         print(f"检测设备:{device} | 置信度阈值:{conf_threshold}") | |||
|  | 
 | |||
|  |         # 图像尺寸信息 | |||
|  |         img_height, img_width = image_np.shape[:2] | |||
|  |         print(f"输入图像尺寸:{img_width}x{img_height}") | |||
|  | 
 | |||
|  |         # YOLO检测 | |||
|  |         print("执行YOLO检测") | |||
|  |         results = model.predict( | |||
|  |             image_np, | |||
|  |             conf=conf_threshold, | |||
|  |             device=device, | |||
|  |             show=False, | |||
|  |         ) | |||
|  | 
 | |||
|  |         # 4. 整理检测结果(仅保留Chest类别,ID=2) | |||
|  |         for box in results[0].boxes: | |||
|  |             class_id = int(box.cls[0])  # 类别ID | |||
|  |             class_name = model.names[class_id] | |||
|  |             confidence = float(box.conf[0]) | |||
|  |             bbox = tuple(map(int, box.xyxy[0])) | |||
|  | 
 | |||
|  |             # 过滤条件:置信度达标 + 类别为Chest(class_id=2) | |||
|  |             # and class_id == 2 | |||
|  |             if confidence >= conf_threshold: | |||
|  |                 detection_results.append({ | |||
|  |                     "class": class_name, | |||
|  |                     "confidence": confidence, | |||
|  |                     "bbox": bbox | |||
|  |                 }) | |||
|  | 
 | |||
|  |         # 判断是否有目标 | |||
|  |         has_content = len(detection_results) > 0 | |||
|  |         return has_content, detection_results | |||
|  | 
 | |||
|  |     except Exception as e: | |||
|  |         error_msg = f"检测过程出错:{str(e)}" | |||
|  |         print(error_msg) | |||
|  |         return False, None | |||
|  | 
 | |||
|  | 
 | |||
|  | def get_enabled_model_rel_path(): | |||
|  |     """获取数据库中启用的模型相对路径""" | |||
|  |     conn = None | |||
|  |     cursor = None | |||
|  |     try: | |||
|  |         conn = db.get_connection() | |||
|  |         cursor = conn.cursor(dictionary=True) | |||
|  |         query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1" | |||
|  |         cursor.execute(query) | |||
|  |         result = cursor.fetchone() | |||
|  | 
 | |||
|  |         if not result or not result.get('path'): | |||
|  |             raise HTTPException(status_code=404, detail="未找到启用的默认模型") | |||
|  | 
 | |||
|  |         return result['path'] | |||
|  |     except MySQLError as e: | |||
|  |         raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e | |||
|  |     except Exception as e: | |||
|  |         if isinstance(e, HTTPException): | |||
|  |             raise e | |||
|  |         raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e | |||
|  |     finally: | |||
|  |         db.close_connection(conn, cursor) |