152 lines
5.2 KiB
Python
152 lines
5.2 KiB
Python
from ultralytics import YOLO
|
||
import cv2
|
||
import numpy as np
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
import tensorrt as trt
|
||
import pycuda.driver as cuda
|
||
import pycuda.autoinit
|
||
|
||
class YOLODetector:
    """Run a YOLO detector on BGR frames and draw the detections.

    Two drawing modes are provided: ``detect_and_draw_English`` uses the
    ultralytics built-in ``Results.plot()``; ``detect_and_draw_Chinese`` draws
    boxes and Chinese labels itself with PIL, using per-class colours.
    """

    # Default English class name -> Chinese display name.
    _DEFAULT_CLASS_NAMES = {
        'pedestrian': '行人',
        'people': '人群',
        'bicycle': '自行车',
        'car': '轿车',
        'van': '面包车',
        'truck': '卡车',
        'tricycle': '三轮车',
        'awning-tricycle': '篷式三轮车',
        'bus': '公交车',
        'motor': '摩托车',
    }

    # Default fixed RGB colour per class (drawn on an RGB PIL image).
    _DEFAULT_COLORS = {
        'pedestrian': (71, 0, 36),  # burgundy
        'people': (0, 255, 0),  # green
        'bicycle': (0, 49, 83),  # Prussian blue
        'car': (0, 47, 167),  # Klein blue
        'van': (128, 0, 128),  # purple
        'truck': (212, 72, 72),  # Titian red
        'tricycle': (255, 165, 0),  # orange (distinct from bicycle's colour)
        'awning-tricycle': (251, 220, 106),  # Schönbrunn yellow
        'bus': (73, 45, 34),  # Van Dyke brown
        'motor': (1, 132, 127),  # Marrs green
    }

    def __init__(self, model_path='models/best.engine', font_path='simhei.ttf', font_size=20):
        """Load the model and set up class names, colours and the label font.

        Args:
            model_path: Model file passed to ``ultralytics.YOLO`` (task "detect").
            font_path: TrueType font used for the Chinese labels.
            font_size: Font size used for the labels.
        """
        # Load the detection model (a TensorRT engine by default).
        self.model = YOLO(model_path, task="detect")

        # Per-instance copies so mutating one detector does not affect others.
        self.class_name_mapping = dict(self._DEFAULT_CLASS_NAMES)
        self.color_mapping = dict(self._DEFAULT_COLORS)

        # One counter per known class, all starting at 0.
        # NOTE(review): no method in this class updates these counters.
        self.class_counts = dict.fromkeys(self.class_name_mapping, 0)

        # Fall back to PIL's default font if the requested font cannot be
        # opened; the default font may not contain CJK glyphs.
        try:
            self.font = ImageFont.truetype(font_path, font_size)
        except OSError:
            self.font = ImageFont.load_default()

    @staticmethod
    def _label_top(y1, label_height):
        """Return the top y coordinate for a label of ``label_height`` pixels.

        The label is normally placed directly above the box whose top edge is
        ``y1``. If that would put it above the image (y < 0), it is placed
        inside the box, just below the top edge, instead.
        """
        top = y1 - label_height
        if top < 0:
            top = max(y1, 0)
        return top

    def detect_and_draw_English(self, frame, conf=0.3, iou=0.5):
        """Detect objects in a frame and draw them with ultralytics' plotter.

        Args:
            frame: Input image frame (BGR).
            conf: Confidence threshold.
            iou: IoU threshold for NMS.

        Returns:
            The annotated frame, or the unmodified input frame if detection
            or drawing raised an exception.
        """
        try:
            results = self.model(
                frame,
                conf=conf,
                iou=iou,
                half=True,
            )
            result = results[0]

            # Built-in drawing (English class names, ultralytics colours).
            annotated_frame = result.plot()

            return annotated_frame

        except Exception as e:
            # Best-effort: keep the video pipeline running on failure.
            print(f"Detection error: {e}")
            return frame

    def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3):
        """Detect objects in a frame and draw boxes with Chinese labels.

        Args:
            frame: Input image frame (BGR).
            conf: Confidence threshold.
            iou: IoU threshold for NMS.

        Returns:
            A new BGR frame with the detections drawn, or the unmodified input
            frame if detection or drawing raised an exception.
        """
        try:
            results = self.model(
                frame,
                conf=conf,
                iou=iou,
                # half=True,
            )
            result = results[0]

            # Draw with PIL so Chinese text can be rendered. cvtColor returns
            # a new array, so the caller's frame is never modified.
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(pil_img)

            pad = 2  # vertical padding around the label text
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)

                cls_id = int(box.cls[0].item())
                # Named `score` so it does not shadow the `conf` threshold.
                score = box.conf[0].item()

                cls_name = result.names[cls_id]
                chinese_name = self.class_name_mapping.get(cls_name, cls_name)
                color = self.color_mapping.get(cls_name, (255, 255, 255))

                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)

                text = f"{chinese_name} {score:.2f}"
                left, top, right, bottom = draw.textbbox((0, 0), text, font=self.font)
                text_width = right - left
                text_height = bottom - top
                label_height = text_height + 2 * pad

                # Keep the label on-screen for boxes touching the top edge.
                label_top = self._label_top(y1, label_height)

                # Label background in the same colour as the box outline.
                draw.rectangle(
                    [(x1, label_top), (x1 + text_width, label_top + label_height)],
                    fill=color,
                )

                # textbbox offsets (left/top) are subtracted so the glyphs
                # land inside the background rather than shifted by them.
                draw.text(
                    (x1 - left, label_top + pad - top),
                    text,
                    fill=(255, 255, 255),
                    font=self.font,
                )

            # Back to OpenCV's BGR layout.
            return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

        except Exception as e:
            # Best-effort: keep the video pipeline running on failure.
            print(f"Detection error: {e}")
            return frame
|