47 lines
1.3 KiB
Python
47 lines
1.3 KiB
Python
from ultralytics import YOLO
|
||
import cv2
|
||
|
||
class ViolationDetector:
|
||
"""
|
||
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
|
||
"""
|
||
def __init__(self, model_path):
|
||
"""
|
||
初始化检测器。
|
||
|
||
Args:
|
||
model_path (str): YOLO .pt模型的路径。
|
||
"""
|
||
print(f"正在从 '{model_path}' 加载YOLO模型...")
|
||
self.model = YOLO(model_path)
|
||
print("YOLO模型加载成功。")
|
||
|
||
def detect(self, frame):
|
||
"""
|
||
对单帧图像进行目标检测。
|
||
|
||
Args:
|
||
frame: 输入的图像帧 (NumPy数组, BGR格式)。
|
||
|
||
Returns:
|
||
ultralytics.engine.results.Results: YOLO的检测结果对象。
|
||
"""
|
||
# conf可以根据您的模型效果进行调整
|
||
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
|
||
results = self.model(frame, conf=0.2)
|
||
return results[0]
|
||
|
||
def draw_boxes(self, frame, result):
|
||
"""
|
||
在图像帧上绘制检测框。
|
||
|
||
Args:
|
||
frame: 原始图像帧。
|
||
result: YOLO的检测结果对象。
|
||
|
||
Returns:
|
||
numpy.ndarray: 绘制了检测框的图像帧。
|
||
"""
|
||
# 使用YOLO自带的plot功能,方便快捷
|
||
annotated_frame = result.plot()
|
||
return annotated_frame |