49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
|
from ultralytics import YOLO
|
|||
|
import cv2
|
|||
|
from logger_config import logger
|
|||
|
|
|||
|
class ViolationDetector:
|
|||
|
"""
|
|||
|
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
|
|||
|
"""
|
|||
|
def __init__(self, model_path):
|
|||
|
"""
|
|||
|
初始化检测器。
|
|||
|
|
|||
|
Args:
|
|||
|
model_path (str): YOLO .pt模型的路径。
|
|||
|
"""
|
|||
|
logger.info(f"正在从 '{model_path}' 加载YOLO模型...")
|
|||
|
self.model = YOLO(model_path)
|
|||
|
logger.info("YOLO模型加载成功。")
|
|||
|
|
|||
|
def detect(self, frame):
|
|||
|
"""
|
|||
|
对单帧图像进行目标检测。
|
|||
|
|
|||
|
Args:
|
|||
|
frame: 输入的图像帧 (NumPy数组, BGR格式)。
|
|||
|
|
|||
|
Returns:
|
|||
|
ultralytics.engine.results.Results: YOLO的检测结果对象。
|
|||
|
"""
|
|||
|
# conf可以根据您的模型效果进行调整
|
|||
|
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
|
|||
|
results = self.model(frame, conf=0.2)
|
|||
|
return results[0]
|
|||
|
|
|||
|
def draw_boxes(self, frame, result):
|
|||
|
"""
|
|||
|
在图像帧上绘制检测框。
|
|||
|
|
|||
|
Args:
|
|||
|
frame: 原始图像帧。
|
|||
|
result: YOLO的检测结果对象。
|
|||
|
|
|||
|
Returns:
|
|||
|
numpy.ndarray: 绘制了检测框的图像帧。
|
|||
|
"""
|
|||
|
# 使用YOLO自带的plot功能,方便快捷
|
|||
|
annotated_frame = result.plot()
|
|||
|
return annotated_frame
|