参数分离完全版
This commit is contained in:
151
yolo_core.py
Normal file
151
yolo_core.py
Normal file
@ -0,0 +1,151 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
import pycuda.autoinit
|
||||
|
||||
class YOLODetector:
    """Run a YOLO detection model on frames and draw the detections.

    Two drawing entry points are provided:

    * ``detect_and_draw_English`` uses the plotting built into the model's
      result object.
    * ``detect_and_draw_Chinese`` renders boxes and Chinese labels through
      PIL, using the class-name and color tables below.
    """

    # English class name (as reported by the model) -> Chinese display name.
    CLASS_NAME_MAPPING = {
        'pedestrian': '行人',
        'people': '人群',
        'bicycle': '自行车',
        'car': '轿车',
        'van': '面包车',
        'truck': '卡车',
        'tricycle': '三轮车',
        'awning-tricycle': '篷式三轮车',
        'bus': '公交车',
        'motor': '摩托车',
    }

    # Fixed RGB color per class. These tuples are handed to PIL on an RGB
    # image, so they are RGB rather than OpenCV's BGR order.
    COLOR_MAPPING = {
        'pedestrian': (71, 0, 36),           # Burgundy red
        'people': (0, 255, 0),               # green
        'bicycle': (0, 49, 83),              # Prussian blue
        'car': (0, 47, 167),                 # Klein blue
        'van': (128, 0, 128),                # purple
        'truck': (212, 72, 72),              # Titian red
        # Previously a duplicate of the bicycle color although labelled
        # orange; use an actual orange so the two classes are distinguishable.
        'tricycle': (255, 165, 0),           # orange
        'awning-tricycle': (251, 220, 106),  # Schoenbrunn yellow
        'bus': (73, 45, 34),                 # Van Dyck brown
        'motor': (1, 132, 127),              # Mars green
    }

    # Box color used for any class name missing from the color table.
    DEFAULT_COLOR = (255, 255, 255)

    # Vertical padding (pixels) added to the text height for the label
    # background rectangle.
    LABEL_PADDING = 4

    def __init__(self, model_path='models/best.engine'):
        """Load the model and set up label tables, counters and the font.

        Args:
            model_path: Path of the model file passed to ``YOLO``. The
                default points at a ``.engine`` file.
        """
        self.model = YOLO(model_path, task="detect")

        # Per-instance copies so callers that mutate these attributes do not
        # affect other detector instances.
        self.class_name_mapping = dict(self.CLASS_NAME_MAPPING)
        self.color_mapping = dict(self.COLOR_MAPPING)

        # One counter per known class, all starting at zero. Nothing inside
        # this class updates them.
        self.class_counts = {cls_name: 0 for cls_name in self.class_name_mapping}

        # Prefer the SimHei font for the labels; fall back to PIL's built-in
        # default font when the font file cannot be opened.
        try:
            self.font = ImageFont.truetype("simhei.ttf", 20)
        except IOError:
            self.font = ImageFont.load_default()

    def detect_and_draw_English(self, frame, conf=0.3, iou=0.5):
        """Detect objects in a frame and draw them with the model's plotter.

        Args:
            frame: Input image frame (BGR format).
            conf: Confidence threshold passed to the model.
            iou: IoU threshold passed to the model.

        Returns:
            The annotated frame produced by ``result.plot()``. If anything
            raises during detection or plotting, the error is printed and
            the original ``frame`` is returned unchanged.
        """
        try:
            results = self.model(
                frame,
                conf=conf,
                iou=iou,
                half=True,
            )
            # Only the first result is used.
            return results[0].plot()

        except Exception as e:
            # Best-effort: a failure on one frame must not stop the caller.
            print(f"Detection error: {e}")
            return frame

    def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3):
        """Detect objects in a frame and draw boxes with Chinese labels.

        Drawing is done on a PIL copy of the frame; the input frame itself
        is not modified.

        Args:
            frame: Input image frame (BGR format).
            conf: Confidence threshold passed to the model.
            iou: IoU threshold passed to the model.

        Returns:
            A new BGR frame with boxes and labels drawn. If anything raises
            during detection or drawing, the error is printed and the
            original ``frame`` is returned unchanged.
        """
        try:
            # Unlike detect_and_draw_English, half precision is not requested.
            results = self.model(
                frame,
                conf=conf,
                iou=iou,
            )
            result = results[0]

            # cvtColor returns a new array, so the caller's frame is left
            # untouched without an explicit copy.
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(pil_img)

            for box in result.boxes:
                # Integer corner coordinates of the box.
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()

                cls_id = int(box.cls[0].item())
                # Named `score` so the `conf` threshold argument is not
                # shadowed inside the loop.
                score = box.conf[0].item()

                # Unknown class names fall back to the raw name / white.
                cls_name = result.names[cls_id]
                chinese_name = self.class_name_mapping.get(cls_name, cls_name)
                color = self.color_mapping.get(cls_name, self.DEFAULT_COLOR)

                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)

                text = f"{chinese_name} {score:.2f}"
                left, top, right, bottom = draw.textbbox(
                    (0, 0), text, font=self.font
                )
                text_width = right - left
                label_height = (bottom - top) + self.LABEL_PADDING

                # Put the label above the box when there is room; otherwise
                # (box touching the top of the image) draw it just inside
                # the box so it stays visible.
                label_top = y1 - label_height
                if label_top < 0:
                    label_top = max(y1, 0)

                # Label background in the same color as the box outline.
                draw.rectangle(
                    [(x1, label_top), (x1 + text_width, label_top + label_height)],
                    fill=color,
                )
                # White label text, inset 2 px from the background's top.
                draw.text(
                    (x1, label_top + 2),
                    text,
                    fill=(255, 255, 255),
                    font=self.font,
                )

            # Convert back to OpenCV's BGR layout.
            return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

        except Exception as e:
            # Best-effort: a failure on one frame must not stop the caller.
            print(f"Detection error: {e}")
            return frame
|
Reference in New Issue
Block a user