内容安全审核
This commit is contained in:
		
							
								
								
									
										131
									
								
								service/model_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								service/model_service.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,131 @@ | ||||
| from http.client import HTTPException | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
| from MySQLdb import MySQLError | ||||
| from ultralytics import YOLO | ||||
| import os | ||||
|  | ||||
| from ds.db import db | ||||
| from service.file_service import get_absolute_path | ||||
|  | ||||
| # 全局变量 | ||||
| current_yolo_model = None | ||||
| current_model_absolute_path = None  # 存储模型绝对路径,不依赖model实例 | ||||
|  | ||||
| ALLOWED_MODEL_EXT = {"pt"} | ||||
| MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB | ||||
|  | ||||
|  | ||||
| def load_yolo_model(): | ||||
|     """加载模型并存储绝对路径""" | ||||
|     global current_yolo_model, current_model_absolute_path | ||||
|     model_rel_path = get_enabled_model_rel_path() | ||||
|     print(f"[模型初始化] 加载模型:{model_rel_path}") | ||||
|  | ||||
|     # 计算并存储绝对路径 | ||||
|     current_model_absolute_path = get_absolute_path(model_rel_path) | ||||
|     print(f"[模型初始化] 绝对路径:{current_model_absolute_path}") | ||||
|  | ||||
|     # 检查模型文件 | ||||
|     if not os.path.exists(current_model_absolute_path): | ||||
|         raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}") | ||||
|  | ||||
|     try: | ||||
|         new_model = YOLO(current_model_absolute_path) | ||||
|         if torch.cuda.is_available(): | ||||
|             new_model.to('cuda') | ||||
|             print("模型已移动到GPU") | ||||
|         else: | ||||
|             print("使用CPU进行推理") | ||||
|         current_yolo_model = new_model | ||||
|         print(f"成功加载模型: {current_model_absolute_path}") | ||||
|         return current_yolo_model | ||||
|     except Exception as e: | ||||
|         print(f"模型加载失败:{str(e)}") | ||||
|         raise | ||||
|  | ||||
|  | ||||
| def get_current_model(): | ||||
|     """获取当前模型实例""" | ||||
|     if current_yolo_model is None: | ||||
|         raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型") | ||||
|     return current_yolo_model | ||||
|  | ||||
|  | ||||
| def detect(image_np, conf_threshold=0.8): | ||||
|     # 1. 输入格式验证 | ||||
|     if not isinstance(image_np, np.ndarray): | ||||
|         raise ValueError("输入必须是numpy数组(BGR图像)") | ||||
|     if image_np.ndim != 3 or image_np.shape[-1] != 3: | ||||
|         raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}") | ||||
|     detection_results = [] | ||||
|     try: | ||||
|         model = get_current_model() | ||||
|         if not current_model_absolute_path: | ||||
|             raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型") | ||||
|         device = "cuda" if torch.cuda.is_available() else "cpu" | ||||
|         print(f"检测设备:{device} | 置信度阈值:{conf_threshold}") | ||||
|  | ||||
|         # 图像尺寸信息 | ||||
|         img_height, img_width = image_np.shape[:2] | ||||
|         print(f"输入图像尺寸:{img_width}x{img_height}") | ||||
|  | ||||
|         # YOLO检测 | ||||
|         print("执行YOLO检测") | ||||
|         results = model.predict( | ||||
|             image_np, | ||||
|             conf=conf_threshold, | ||||
|             device=device, | ||||
|             show=False, | ||||
|         ) | ||||
|  | ||||
|         # 4. 整理检测结果(仅保留Chest类别,ID=2) | ||||
|         for box in results[0].boxes: | ||||
|             class_id = int(box.cls[0])  # 类别ID | ||||
|             class_name = model.names[class_id] | ||||
|             confidence = float(box.conf[0]) | ||||
|             bbox = tuple(map(int, box.xyxy[0])) | ||||
|  | ||||
|             # 过滤条件:置信度达标 + 类别为Chest(class_id=2) | ||||
|             # and class_id == 2 | ||||
|             if confidence >= conf_threshold: | ||||
|                 detection_results.append({ | ||||
|                     "class": class_name, | ||||
|                     "confidence": confidence, | ||||
|                     "bbox": bbox | ||||
|                 }) | ||||
|  | ||||
|         # 判断是否有目标 | ||||
|         has_content = len(detection_results) > 0 | ||||
|         return has_content, detection_results | ||||
|  | ||||
|     except Exception as e: | ||||
|         error_msg = f"检测过程出错:{str(e)}" | ||||
|         print(error_msg) | ||||
|         return False, None | ||||
|  | ||||
|  | ||||
| def get_enabled_model_rel_path(): | ||||
|     """获取数据库中启用的模型相对路径""" | ||||
|     conn = None | ||||
|     cursor = None | ||||
|     try: | ||||
|         conn = db.get_connection() | ||||
|         cursor = conn.cursor(dictionary=True) | ||||
|         query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1" | ||||
|         cursor.execute(query) | ||||
|         result = cursor.fetchone() | ||||
|  | ||||
|         if not result or not result.get('path'): | ||||
|             raise HTTPException(status_code=404, detail="未找到启用的默认模型") | ||||
|  | ||||
|         return result['path'] | ||||
|     except MySQLError as e: | ||||
|         raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e | ||||
|     except Exception as e: | ||||
|         if isinstance(e, HTTPException): | ||||
|             raise e | ||||
|         raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e | ||||
|     finally: | ||||
|         db.close_connection(conn, cursor) | ||||
		Reference in New Issue
	
	Block a user