Files
video/ocr/yolo_violation_detector.py

47 lines
1.3 KiB
Python
Raw Normal View History

2025-09-03 14:38:42 +08:00
from ultralytics import YOLO
import cv2
class ViolationDetector:
"""
用于加载YOLOv8 .pt模型并进行违规内容检测的类
"""
def __init__(self, model_path):
"""
初始化检测器
Args:
model_path (str): YOLO .pt模型的路径
"""
2025-09-03 16:22:21 +08:00
print(f"正在从 '{model_path}' 加载YOLO模型...")
2025-09-03 14:38:42 +08:00
self.model = YOLO(model_path)
2025-09-03 16:22:21 +08:00
print("YOLO模型加载成功。")
2025-09-03 14:38:42 +08:00
def detect(self, frame):
"""
对单帧图像进行目标检测
Args:
frame: 输入的图像帧 (NumPy数组, BGR格式)
Returns:
ultralytics.engine.results.Results: YOLO的检测结果对象
"""
# conf可以根据您的模型效果进行调整
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
results = self.model(frame, conf=0.2)
return results[0]
def draw_boxes(self, frame, result):
"""
在图像帧上绘制检测框
Args:
frame: 原始图像帧
result: YOLO的检测结果对象
Returns:
numpy.ndarray: 绘制了检测框的图像帧
"""
# 使用YOLO自带的plot功能方便快捷
annotated_frame = result.plot()
2025-09-03 16:22:21 +08:00
return annotated_frame