ocr1.0
This commit is contained in:
48
ocr/yolo_violation_detector.py
Normal file
48
ocr/yolo_violation_detector.py
Normal file
@ -0,0 +1,48 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
from logger_config import logger
|
||||
|
||||
class ViolationDetector:
|
||||
"""
|
||||
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
|
||||
"""
|
||||
def __init__(self, model_path):
|
||||
"""
|
||||
初始化检测器。
|
||||
|
||||
Args:
|
||||
model_path (str): YOLO .pt模型的路径。
|
||||
"""
|
||||
logger.info(f"正在从 '{model_path}' 加载YOLO模型...")
|
||||
self.model = YOLO(model_path)
|
||||
logger.info("YOLO模型加载成功。")
|
||||
|
||||
def detect(self, frame):
|
||||
"""
|
||||
对单帧图像进行目标检测。
|
||||
|
||||
Args:
|
||||
frame: 输入的图像帧 (NumPy数组, BGR格式)。
|
||||
|
||||
Returns:
|
||||
ultralytics.engine.results.Results: YOLO的检测结果对象。
|
||||
"""
|
||||
# conf可以根据您的模型效果进行调整
|
||||
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
|
||||
results = self.model(frame, conf=0.2)
|
||||
return results[0]
|
||||
|
||||
def draw_boxes(self, frame, result):
|
||||
"""
|
||||
在图像帧上绘制检测框。
|
||||
|
||||
Args:
|
||||
frame: 原始图像帧。
|
||||
result: YOLO的检测结果对象。
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: 绘制了检测框的图像帧。
|
||||
"""
|
||||
# 使用YOLO自带的plot功能,方便快捷
|
||||
annotated_frame = result.plot()
|
||||
return annotated_frame
|
Reference in New Issue
Block a user