55 lines
1.5 KiB
Python
55 lines
1.5 KiB
Python
import os
|
||
|
||
import cv2
|
||
from ultralytics import YOLO
|
||
|
||
# Module-level state shared by load_model() and detect().
# The loaded YOLO instance; remains None until load_model() succeeds.
_yolo_model: "YOLO | None" = None

# Default weights file: <directory of this module>/models/best.pt
model_path: str = os.path.join(
    os.path.dirname(__file__),
    "models",
    "best.pt",
)
|
||
|
||
|
||
def load_model(path=None):
    """Load the YOLO object-detection model into the module-level cache.

    Args:
        path: Optional path to the weights file. When omitted, the
            module-level ``model_path`` is used.

    Returns:
        True if the model was constructed and cached, False otherwise.
        On failure the error is printed and the previously cached model
        (if any) is left unchanged.
    """
    global _yolo_model

    weights = model_path if path is None else path
    try:
        model = YOLO(weights)
    except Exception as e:
        # Boundary function: report the failure through the return value
        # instead of propagating, as callers expect a bool.
        print(f"YOLO model load failed: {e}")
        return False

    # Only replace the cached model once construction has succeeded, and
    # report success from that fact rather than from the object's truthiness.
    _yolo_model = model
    return True
|
||
|
||
|
||
def detect(frame, conf_threshold=0.2):
    """Run YOLO detection on one frame and summarise the detected boxes.

    Args:
        frame: Image passed straight to the cached YOLO model.
        conf_threshold: Minimum confidence forwarded as ``conf=`` to the model.

    Returns:
        A tuple ``(found, text)``. ``found`` is True only when at least one
        box was detected; ``text`` is then a "; "-joined description of each
        box (class name, confidence, xyxy bounding box). Otherwise ``text``
        is a status/error message. Exceptions raised during inference are
        printed and reported via the returned message, never propagated.
    """
    if _yolo_model is None or frame is None:
        return (False, "未初始化或无效帧")

    try:
        results = _yolo_model(frame, conf=conf_threshold)

        # Results.boxes can be None for models that don't produce boxes;
        # treat that the same as "nothing detected" instead of failing on len().
        boxes = results[0].boxes if results else None
        if boxes is None or len(boxes) == 0:
            return (False, "未检测到目标")

        # Hoisted out of the loop: the class-name mapping is the same for
        # every box, so look it up once.
        names = getattr(_yolo_model, "names", None)

        result_parts = []
        for box in boxes:
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            bbox = [float(x) for x in box.xyxy[0]]
            class_name = names[cls] if names is not None else f"类别{cls}"
            result_parts.append(f"{class_name} (置信度: {conf:.2f}, 边界框: {bbox})")

        return (True, "; ".join(result_parts))

    except Exception as e:
        print(f"YOLO detect error: {e}")
        return (False, f"检测错误: {e}")