Files
VisDrone-Version/yolo_core.py

152 lines
5.2 KiB
Python
Raw Permalink Normal View History

2025-08-05 16:55:45 +08:00
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
class YOLODetector:
    """Run a YOLO detection model on frames and draw the detections.

    Two drawing modes are offered: the model library's built-in plotting
    (English labels) and a PIL-based renderer that writes Chinese labels
    with a fixed colour per class.
    """

    # English class name -> Chinese display label.
    DEFAULT_CLASS_NAMES = {
        'pedestrian': '行人',
        'people': '人群',
        'bicycle': '自行车',
        'car': '轿车',
        'van': '面包车',
        'truck': '卡车',
        'tricycle': '三轮车',
        'awning-tricycle': '篷式三轮车',
        'bus': '公交车',
        'motor': '摩托车',
    }

    # Fixed RGB colour per class (used on the PIL image, which is RGB).
    DEFAULT_COLORS = {
        'pedestrian': (71, 0, 36),           # burgundy
        'people': (0, 255, 0),               # green
        'bicycle': (0, 49, 83),              # Prussian blue
        'car': (0, 47, 167),                 # Klein blue
        'van': (128, 0, 128),                # purple
        'truck': (212, 72, 72),              # Titian red
        'tricycle': (255, 165, 0),           # orange (was a copy of the bicycle colour)
        'awning-tricycle': (251, 220, 106),  # Schoenbrunn yellow
        'bus': (73, 45, 34),                 # Van Dyke brown
        'motor': (1, 132, 127),              # Mars green
    }

    # Colour used for classes that have no entry in the colour table.
    _FALLBACK_COLOR = (255, 255, 255)
    # Vertical space (pixels) added to the text height for the label background.
    _LABEL_PADDING = 4

    def __init__(self, model_path='models/best.engine',
                 font_path='simhei.ttf', font_size=20):
        """Load the model and prepare the label tables and font.

        Args:
            model_path: Path of the model file handed to ``YOLO``.
            font_path: TrueType font used for the Chinese labels.
            font_size: Font size in points.
        """
        # Load the detection model (a TensorRT engine by default).
        self.model = YOLO(model_path, task="detect")

        # Per-instance copies so callers may customise them without
        # affecting other detectors.
        self.class_name_mapping = dict(self.DEFAULT_CLASS_NAMES)
        self.color_mapping = dict(self.DEFAULT_COLORS)

        # Per-class counters, all starting at zero.
        self.class_counts = {name: 0 for name in self.class_name_mapping}

        # Fall back to PIL's built-in font when the requested one is missing.
        try:
            self.font = ImageFont.truetype(font_path, font_size)
        except OSError:
            print(
                f"Font '{font_path}' could not be loaded; using the default "
                "font (Chinese labels may not render)"
            )
            self.font = ImageFont.load_default()

    @staticmethod
    def _label_top(box_top, text_height, padding=4):
        """Return the y coordinate of the top edge of a label background.

        The label normally sits directly above the box; when that would
        place it above the image (negative y) it is clamped to 0 so the
        label stays visible.
        """
        return max(box_top - text_height - padding, 0)

    def detect_and_draw_English(self, frame, conf=0.3, iou=0.5):
        """Detect objects and draw them with the model's built-in plotting.

        Args:
            frame: Input image frame in BGR format.
            conf: Confidence threshold.
            iou: IoU threshold.

        Returns:
            The annotated frame, or the unmodified input frame if
            detection or drawing fails.
        """
        try:
            results = self.model(
                frame,
                conf=conf,
                iou=iou,
                half=True,
            )
            # Only the first result is used.
            return results[0].plot()
        except Exception as e:
            # Best effort: report the error and hand back the input frame.
            print(f"Detection error: {e}")
            return frame

    def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3):
        """Detect objects and draw boxes with Chinese class labels.

        Args:
            frame: Input image frame in BGR format.
            conf: Confidence threshold.
            iou: IoU threshold.

        Returns:
            A new BGR frame with boxes and labels drawn, or the unmodified
            input frame if detection or drawing fails.
        """
        try:
            # Unlike detect_and_draw_English, half precision is not requested.
            results = self.model(
                frame,
                conf=conf,
                iou=iou,
            )
            result = results[0]

            # cvtColor returns a new array, so the input frame is not modified.
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(pil_img)

            # Convert all detections in one step instead of once per box.
            boxes = result.boxes.cpu().numpy()
            detections = zip(
                boxes.xyxy.astype(int).tolist(),
                boxes.cls.astype(int).tolist(),
                boxes.conf.tolist(),
            )

            padding = self._LABEL_PADDING
            for (x1, y1, x2, y2), cls_id, score in detections:
                # Unknown classes keep their English name and a white colour.
                cls_name = result.names[cls_id]
                chinese_name = self.class_name_mapping.get(cls_name, cls_name)
                color = self.color_mapping.get(cls_name, self._FALLBACK_COLOR)

                # Bounding box.
                draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)

                # Label text and its rendered size.
                text = f"{chinese_name} {score:.2f}"
                left, top, right, bottom = draw.textbbox(
                    (0, 0), text, font=self.font
                )
                text_width = right - left
                text_height = bottom - top

                # Label background in the box colour, kept inside the image.
                label_top = self._label_top(y1, text_height, padding)
                draw.rectangle(
                    [
                        (x1, label_top),
                        (x1 + text_width, label_top + text_height + padding),
                    ],
                    fill=color,
                )
                # White label text.
                draw.text(
                    (x1, label_top + 2),
                    text,
                    fill=(255, 255, 255),
                    font=self.font,
                )

            # Back to OpenCV's BGR layout.
            return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        except Exception as e:
            # Best effort: report the error and hand back the input frame.
            print(f"Detection error: {e}")
            return frame