This commit is contained in:
2025-09-03 14:38:42 +08:00
parent eb5cf715ec
commit b7773f5f00
19 changed files with 546 additions and 168 deletions

View File

@ -0,0 +1,48 @@
from ultralytics import YOLO
import cv2
from logger_config import logger
class ViolationDetector:
"""
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
"""
def __init__(self, model_path):
"""
初始化检测器。
Args:
model_path (str): YOLO .pt模型的路径。
"""
logger.info(f"正在从 '{model_path}' 加载YOLO模型...")
self.model = YOLO(model_path)
logger.info("YOLO模型加载成功。")
def detect(self, frame):
"""
对单帧图像进行目标检测。
Args:
frame: 输入的图像帧 (NumPy数组, BGR格式)。
Returns:
ultralytics.engine.results.Results: YOLO的检测结果对象。
"""
# conf可以根据您的模型效果进行调整
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
results = self.model(frame, conf=0.2)
return results[0]
def draw_boxes(self, frame, result):
"""
在图像帧上绘制检测框。
Args:
frame: 原始图像帧。
result: YOLO的检测结果对象。
Returns:
numpy.ndarray: 绘制了检测框的图像帧。
"""
# 使用YOLO自带的plot功能方便快捷
annotated_frame = result.plot()
return annotated_frame