| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  | from fastapi import HTTPException | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | import numpy as np | 
					
						
							|  |  |  |  | import torch | 
					
						
							|  |  |  |  | from MySQLdb import MySQLError | 
					
						
							|  |  |  |  | from ultralytics import YOLO | 
					
						
							|  |  |  |  | import os | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from ds.db import db | 
					
						
							|  |  |  |  | from service.file_service import get_absolute_path | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  | # 全局变量:初始化时为None,无模型时保持None | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | current_yolo_model = None | 
					
						
							|  |  |  |  | current_model_absolute_path = None  # 存储模型绝对路径,不依赖model实例 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | ALLOWED_MODEL_EXT = {"pt"} | 
					
						
							|  |  |  |  | MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def load_yolo_model(): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     """
 | 
					
						
							|  |  |  |  |     加载模型并存储绝对路径 | 
					
						
							|  |  |  |  |     无有效模型路径/模型文件不存在/加载失败时,跳过加载(不抛出异常) | 
					
						
							|  |  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     global current_yolo_model, current_model_absolute_path | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     # 1. 获取数据库中的模型路径(无模型时返回None) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     model_rel_path = get_enabled_model_rel_path() | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     # 2. 无模型路径时,跳过加载 | 
					
						
							|  |  |  |  |     if not model_rel_path: | 
					
						
							|  |  |  |  |         print("[模型初始化] 未获取到有效模型路径,已跳过模型加载") | 
					
						
							|  |  |  |  |         current_yolo_model = None | 
					
						
							|  |  |  |  |         current_model_absolute_path = None | 
					
						
							|  |  |  |  |         return None | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     # 3. 有模型路径时,执行正常加载流程 | 
					
						
							|  |  |  |  |     print(f"[模型初始化] 加载模型:{model_rel_path}") | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 计算绝对路径(避免路径处理异常) | 
					
						
							|  |  |  |  |         current_model_absolute_path = get_absolute_path(model_rel_path) | 
					
						
							|  |  |  |  |         print(f"[模型初始化] 模型绝对路径:{current_model_absolute_path}") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 检查模型文件是否存在 | 
					
						
							|  |  |  |  |         if not os.path.exists(current_model_absolute_path): | 
					
						
							|  |  |  |  |             print(f"[模型初始化] 警告:模型文件不存在({current_model_absolute_path}),已跳过加载") | 
					
						
							|  |  |  |  |             current_yolo_model = None | 
					
						
							|  |  |  |  |             current_model_absolute_path = None | 
					
						
							|  |  |  |  |             return None | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 加载YOLO模型 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         new_model = YOLO(current_model_absolute_path) | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 设备分配(GPU/CPU) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         if torch.cuda.is_available(): | 
					
						
							|  |  |  |  |             new_model.to('cuda') | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             print("[模型初始化] 模型已移动到GPU设备") | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             print("[模型初始化] 未检测到GPU,使用CPU进行推理") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # 更新全局模型变量 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         current_yolo_model = new_model | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         print(f"[模型初始化] 成功加载模型:{current_model_absolute_path}") | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         return current_yolo_model | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     # 捕获所有加载异常,避免中断项目启动 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         print(f"[模型初始化] 警告:模型加载失败({str(e)}),已跳过加载") | 
					
						
							|  |  |  |  |         current_yolo_model = None | 
					
						
							|  |  |  |  |         current_model_absolute_path = None | 
					
						
							|  |  |  |  |         return None | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def get_current_model(): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     """
 | 
					
						
							|  |  |  |  |     获取当前模型实例 | 
					
						
							|  |  |  |  |     无模型时返回None(不抛出异常,避免中断流程) | 
					
						
							|  |  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     return current_yolo_model | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def detect(image_np, conf_threshold=0.8): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     """
 | 
					
						
							|  |  |  |  |     执行YOLO检测 | 
					
						
							|  |  |  |  |     无模型时返回明确提示,不崩溃;有模型时正常返回检测结果 | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     # 优先检查模型是否已加载 | 
					
						
							|  |  |  |  |     model = get_current_model() | 
					
						
							|  |  |  |  |     if not model: | 
					
						
							|  |  |  |  |         error_msg = "检测失败:未加载任何YOLO模型(数据库中无默认模型或模型加载失败)" | 
					
						
							|  |  |  |  |         print(f"[检测流程] {error_msg}") | 
					
						
							|  |  |  |  |         return False, error_msg  # 返回False+错误提示,而非None | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # 2. 输入格式验证(保留原逻辑,格式错误仍抛异常,属于参数问题) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     if not isinstance(image_np, np.ndarray): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         raise ValueError("输入必须是numpy数组(BGR图像格式)") | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     if image_np.ndim != 3 or image_np.shape[-1] != 3: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         raise ValueError(f"输入图像格式错误,需为 (高度, 宽度, 3) 的BGR数组,当前shape: {image_np.shape}") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     detection_results = [] | 
					
						
							|  |  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 3. 检测配置 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         device = "cuda" if torch.cuda.is_available() else "cpu" | 
					
						
							|  |  |  |  |         img_height, img_width = image_np.shape[:2] | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         print(f"[检测流程] 设备:{device} | 置信度阈值:{conf_threshold} | 图像尺寸:{img_width}x{img_height}") | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 4. 执行YOLO预测 | 
					
						
							|  |  |  |  |         print("[检测流程] 开始执行YOLO检测") | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         results = model.predict( | 
					
						
							|  |  |  |  |             image_np, | 
					
						
							|  |  |  |  |             conf=conf_threshold, | 
					
						
							|  |  |  |  |             device=device, | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             show=False,  # 不显示检测窗口 | 
					
						
							|  |  |  |  |             verbose=False  # 关闭YOLO内部日志(可选,减少冗余输出) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 5. 整理检测结果(仅保留置信度达标结果,原逻辑保留) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         for box in results[0].boxes: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             class_id = int(box.cls[0]) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |             class_name = model.names[class_id] | 
					
						
							|  |  |  |  |             confidence = float(box.conf[0]) | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             # 转换为整数坐标(x1, y1, x2, y2) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |             bbox = tuple(map(int, box.xyxy[0])) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |             # 过滤条件:置信度达标 | 
					
						
							|  |  |  |  |             if confidence >= conf_threshold and 0 <= class_id <= 5: | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |                 detection_results.append({ | 
					
						
							|  |  |  |  |                     "class": class_name, | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |                     "confidence": round(confidence, 4),  # 保留4位小数,优化输出 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |                     "bbox": bbox | 
					
						
							|  |  |  |  |                 }) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 6. 判断是否检测到目标 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         has_content = len(detection_results) > 0 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         print(f"[检测流程] 检测完成:共检测到 {len(detection_results)} 个目标") | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         return has_content, detection_results | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     # 7. 捕获检测过程异常,返回明确错误信息 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         error_msg = f"检测过程出错:{str(e)}" | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         print(f"[检测流程] {error_msg}") | 
					
						
							|  |  |  |  |         return False, error_msg | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def get_enabled_model_rel_path(): | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     """
 | 
					
						
							|  |  |  |  |     从数据库获取启用的默认模型相对路径 | 
					
						
							|  |  |  |  |     无模型/数据库错误时返回None,仅记录警告日志 | 
					
						
							|  |  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     conn = None | 
					
						
							|  |  |  |  |     cursor = None | 
					
						
							|  |  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 建立数据库连接 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         conn = db.get_connection() | 
					
						
							|  |  |  |  |         cursor = conn.cursor(dictionary=True) | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 查询默认模型(is_default=1) | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |         query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1" | 
					
						
							|  |  |  |  |         cursor.execute(query) | 
					
						
							|  |  |  |  |         result = cursor.fetchone() | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         # 有有效路径则返回,否则返回None | 
					
						
							|  |  |  |  |         if result and isinstance(result.get('path'), str) and result['path'].strip(): | 
					
						
							|  |  |  |  |             model_path = result['path'].strip() | 
					
						
							|  |  |  |  |             print(f"找到默认模型路径:{model_path}") | 
					
						
							|  |  |  |  |             return model_path | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             print("警告:未找到启用的默认模型") | 
					
						
							|  |  |  |  |             return None | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |     # 捕获MySQL相关错误 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     except MySQLError as e: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         print(f"警告:查询默认模型时发生数据库错误({str(e)})") | 
					
						
							|  |  |  |  |         return None | 
					
						
							|  |  |  |  |     # 捕获其他通用错误 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         print(f"[数据库查询] 警告:获取默认模型路径失败({str(e)})") | 
					
						
							|  |  |  |  |         return None | 
					
						
							|  |  |  |  |     # 确保数据库连接和游标关闭 | 
					
						
							| 
									
										
										
										
											2025-09-30 17:17:20 +08:00
										 |  |  |  |     finally: | 
					
						
							| 
									
										
										
										
											2025-10-10 11:35:37 +08:00
										 |  |  |  |         if cursor: | 
					
						
							|  |  |  |  |             try: | 
					
						
							|  |  |  |  |                 cursor.close() | 
					
						
							|  |  |  |  |                 print("游标已关闭") | 
					
						
							|  |  |  |  |             except Exception as e: | 
					
						
							|  |  |  |  |                 print(f"关闭游标时出错:{str(e)}") | 
					
						
							|  |  |  |  |         # 关闭连接(允许重复关闭,无需检查是否已关闭) | 
					
						
							|  |  |  |  |         if conn: | 
					
						
							|  |  |  |  |             try: | 
					
						
							|  |  |  |  |                 conn.close() | 
					
						
							|  |  |  |  |                 print("数据库连接已关闭") | 
					
						
							|  |  |  |  |             except Exception as e: | 
					
						
							|  |  |  |  |                 print(f"关闭数据库连接时出错:{str(e)}") |