62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
import os
|
||
import numpy as np
|
||
import traceback
|
||
from ultralytics import YOLO
|
||
from typing import Optional
|
||
|
||
|
||
def load_yolo_model(model_path: str) -> Optional[YOLO]:
|
||
"""
|
||
加载YOLO模型(支持v5/v8),并校验模型有效性
|
||
:param model_path: 模型文件的绝对路径
|
||
:return: 加载成功返回YOLO模型实例,失败返回None
|
||
"""
|
||
try:
|
||
# 加载前的基础信息检查
|
||
print(f"\n[模型工具] 开始加载模型:{model_path}")
|
||
print(f"[模型工具] 文件是否存在:{os.path.exists(model_path)}")
|
||
if os.path.exists(model_path):
|
||
print(f"[模型工具] 文件大小:{os.path.getsize(model_path) / 1024 / 1024:.2f} MB")
|
||
|
||
# 强制重新加载模型,避免缓存问题
|
||
model = YOLO(model_path)
|
||
|
||
# 兼容性校验:使用numpy空数组测试模型
|
||
dummy_image = np.zeros((640, 640, 3), dtype=np.uint8)
|
||
|
||
try:
|
||
# 优先使用新版本参数
|
||
model.predict(
|
||
source=dummy_image,
|
||
imgsz=640,
|
||
conf=0.25,
|
||
verbose=False,
|
||
stream=False
|
||
)
|
||
except Exception as pred_e:
|
||
print(f"[模型工具] 预测校验兼容处理:{str(pred_e)}")
|
||
# 兼容旧版本YOLO参数
|
||
model.predict(
|
||
img=dummy_image,
|
||
imgsz=640,
|
||
conf=0.25,
|
||
verbose=False
|
||
)
|
||
|
||
# 验证模型基本属性
|
||
if not hasattr(model, 'names'):
|
||
print("[模型工具] 警告:模型缺少类别名称属性")
|
||
else:
|
||
print(f"[模型工具] 模型包含类别:{list(model.names.values())[:5]}...") # 显示前5个类别
|
||
|
||
print(f"[模型工具] 模型加载成功!")
|
||
return model
|
||
|
||
except Exception as e:
|
||
# 详细错误信息输出
|
||
print(f"\n[模型工具] 加载模型失败!路径:{model_path}")
|
||
print(f"[模型工具] 异常类型:{type(e).__name__}")
|
||
print(f"[模型工具] 异常详情:{str(e)}")
|
||
print(f"[模型工具] 堆栈跟踪:\n{traceback.format_exc()}")
|
||
return None
|