62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
|
import os
|
|||
|
import numpy as np
|
|||
|
import traceback
|
|||
|
from ultralytics import YOLO
|
|||
|
from typing import Optional
|
|||
|
|
|||
|
|
|||
|
def load_yolo_model(model_path: str) -> Optional[YOLO]:
|
|||
|
"""
|
|||
|
加载YOLO模型(支持v5/v8),并校验模型有效性
|
|||
|
:param model_path: 模型文件的绝对路径
|
|||
|
:return: 加载成功返回YOLO模型实例,失败返回None
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 加载前的基础信息检查
|
|||
|
print(f"\n[模型工具] 开始加载模型:{model_path}")
|
|||
|
print(f"[模型工具] 文件是否存在:{os.path.exists(model_path)}")
|
|||
|
if os.path.exists(model_path):
|
|||
|
print(f"[模型工具] 文件大小:{os.path.getsize(model_path) / 1024 / 1024:.2f} MB")
|
|||
|
|
|||
|
# 强制重新加载模型,避免缓存问题
|
|||
|
model = YOLO(model_path)
|
|||
|
|
|||
|
# 兼容性校验:使用numpy空数组测试模型
|
|||
|
dummy_image = np.zeros((640, 640, 3), dtype=np.uint8)
|
|||
|
|
|||
|
try:
|
|||
|
# 优先使用新版本参数
|
|||
|
model.predict(
|
|||
|
source=dummy_image,
|
|||
|
imgsz=640,
|
|||
|
conf=0.25,
|
|||
|
verbose=False,
|
|||
|
stream=False
|
|||
|
)
|
|||
|
except Exception as pred_e:
|
|||
|
print(f"[模型工具] 预测校验兼容处理:{str(pred_e)}")
|
|||
|
# 兼容旧版本YOLO参数
|
|||
|
model.predict(
|
|||
|
img=dummy_image,
|
|||
|
imgsz=640,
|
|||
|
conf=0.25,
|
|||
|
verbose=False
|
|||
|
)
|
|||
|
|
|||
|
# 验证模型基本属性
|
|||
|
if not hasattr(model, 'names'):
|
|||
|
print("[模型工具] 警告:模型缺少类别名称属性")
|
|||
|
else:
|
|||
|
print(f"[模型工具] 模型包含类别:{list(model.names.values())[:5]}...") # 显示前5个类别
|
|||
|
|
|||
|
print(f"[模型工具] 模型加载成功!")
|
|||
|
return model
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
# 详细错误信息输出
|
|||
|
print(f"\n[模型工具] 加载模型失败!路径:{model_path}")
|
|||
|
print(f"[模型工具] 异常类型:{type(e).__name__}")
|
|||
|
print(f"[模型工具] 异常详情:{str(e)}")
|
|||
|
print(f"[模型工具] 堆栈跟踪:\n{traceback.format_exc()}")
|
|||
|
return None
|