62 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			62 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import os | |||
|  | import numpy as np | |||
|  | import traceback | |||
|  | from ultralytics import YOLO | |||
|  | from typing import Optional | |||
|  | 
 | |||
|  | 
 | |||
|  | def load_yolo_model(model_path: str) -> Optional[YOLO]: | |||
|  |     """
 | |||
|  |     加载YOLO模型(支持v5/v8),并校验模型有效性 | |||
|  |     :param model_path: 模型文件的绝对路径 | |||
|  |     :return: 加载成功返回YOLO模型实例,失败返回None | |||
|  |     """
 | |||
|  |     try: | |||
|  |         # 加载前的基础信息检查 | |||
|  |         print(f"\n[模型工具] 开始加载模型:{model_path}") | |||
|  |         print(f"[模型工具] 文件是否存在:{os.path.exists(model_path)}") | |||
|  |         if os.path.exists(model_path): | |||
|  |             print(f"[模型工具] 文件大小:{os.path.getsize(model_path) / 1024 / 1024:.2f} MB") | |||
|  | 
 | |||
|  |         # 强制重新加载模型,避免缓存问题 | |||
|  |         model = YOLO(model_path) | |||
|  | 
 | |||
|  |         # 兼容性校验:使用numpy空数组测试模型 | |||
|  |         dummy_image = np.zeros((640, 640, 3), dtype=np.uint8) | |||
|  | 
 | |||
|  |         try: | |||
|  |             # 优先使用新版本参数 | |||
|  |             model.predict( | |||
|  |                 source=dummy_image, | |||
|  |                 imgsz=640, | |||
|  |                 conf=0.25, | |||
|  |                 verbose=False, | |||
|  |                 stream=False | |||
|  |             ) | |||
|  |         except Exception as pred_e: | |||
|  |             print(f"[模型工具] 预测校验兼容处理:{str(pred_e)}") | |||
|  |             # 兼容旧版本YOLO参数 | |||
|  |             model.predict( | |||
|  |                 img=dummy_image, | |||
|  |                 imgsz=640, | |||
|  |                 conf=0.25, | |||
|  |                 verbose=False | |||
|  |             ) | |||
|  | 
 | |||
|  |         # 验证模型基本属性 | |||
|  |         if not hasattr(model, 'names'): | |||
|  |             print("[模型工具] 警告:模型缺少类别名称属性") | |||
|  |         else: | |||
|  |             print(f"[模型工具] 模型包含类别:{list(model.names.values())[:5]}...")  # 显示前5个类别 | |||
|  | 
 | |||
|  |         print(f"[模型工具] 模型加载成功!") | |||
|  |         return model | |||
|  | 
 | |||
|  |     except Exception as e: | |||
|  |         # 详细错误信息输出 | |||
|  |         print(f"\n[模型工具] 加载模型失败!路径:{model_path}") | |||
|  |         print(f"[模型工具] 异常类型:{type(e).__name__}") | |||
|  |         print(f"[模型工具] 异常详情:{str(e)}") | |||
|  |         print(f"[模型工具] 堆栈跟踪:\n{traceback.format_exc()}") | |||
|  |         return None |