Files
VisDrone-Version/yolo_core.py
2025-08-05 16:55:45 +08:00

152 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging

from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
class YOLODetector:
    """Run a YOLO detector on BGR frames and draw the detections.

    Wraps an Ultralytics ``YOLO`` model (the default path points at a
    TensorRT ``.engine`` file). It offers two renderers:

    * :meth:`detect_and_draw_English` uses Ultralytics' built-in plotting.
    * :meth:`detect_and_draw_Chinese` draws boxes with Pillow, using Chinese
      class names and a fixed colour per class.
    """

    _logger = logging.getLogger(__name__)

    # English class name -> Chinese display name.
    # NOTE(review): keys presumably match the model's ``names`` (VisDrone
    # class set) — confirm against the trained model.
    CLASS_NAME_MAPPING = {
        'pedestrian': '行人',
        'people': '人群',
        'bicycle': '自行车',
        'car': '轿车',
        'van': '面包车',
        'truck': '卡车',
        'tricycle': '三轮车',
        'awning-tricycle': '篷式三轮车',
        'bus': '公交车',
        'motor': '摩托车',
    }

    # Fixed RGB colour per class. These are RGB, not BGR, because drawing
    # happens on a Pillow image converted from the BGR frame.
    COLOR_MAPPING = {
        'pedestrian': (71, 0, 36),            # Burgundy
        'people': (0, 255, 0),                # Green
        'bicycle': (0, 49, 83),               # Prussian blue
        'car': (0, 47, 167),                  # Klein blue
        'van': (128, 0, 128),                 # Purple
        'truck': (212, 72, 72),               # Titian red
        'tricycle': (255, 165, 0),            # Orange (was a duplicate of bicycle's blue)
        'awning-tricycle': (251, 220, 106),   # Schönbrunn yellow
        'bus': (73, 45, 34),                  # Van Dyke brown
        'motor': (1, 132, 127),               # Marrs green
    }

    # Colour for classes missing from COLOR_MAPPING.
    DEFAULT_COLOR = (255, 255, 255)

    # Total vertical padding (pixels) around the label text inside its background.
    _LABEL_PADDING = 4

    def __init__(self, model_path='models/best.engine', font_path='simhei.ttf', font_size=20):
        """Load the model and set up the drawing resources.

        Args:
            model_path: Path passed to ``YOLO`` (default: a TensorRT engine).
            font_path: TrueType font used for the Chinese labels.
            font_size: Font size in pixels for the labels.
        """
        self.model = YOLO(model_path, task="detect")
        # Per-instance copies so that editing one detector's mappings does not
        # affect other instances or the class defaults.
        self.class_name_mapping = dict(self.CLASS_NAME_MAPPING)
        self.color_mapping = dict(self.COLOR_MAPPING)
        # Per-class counters, all starting at zero.
        # NOTE(review): nothing in this class updates them — confirm whether
        # callers rely on this attribute.
        self.class_counts = {name: 0 for name in self.class_name_mapping}
        self.font, self._font_supports_cjk = self._load_font(font_path, font_size)

    @classmethod
    def _load_font(cls, font_path, font_size):
        """Load a TrueType font, falling back to Pillow's default font.

        Returns:
            A ``(font, supports_cjk)`` pair. ``supports_cjk`` is False for the
            fallback, because Pillow's default font has no Chinese glyphs.
        """
        try:
            return ImageFont.truetype(font_path, font_size), True
        except OSError:
            cls._logger.warning(
                "Could not load font %r; falling back to Pillow's default font "
                "and English labels", font_path,
            )
            return ImageFont.load_default(), False

    @classmethod
    def _label_top(cls, box_top, text_height):
        """Return the y coordinate of the top edge of a label's background.

        The label normally sits directly above the box. If that would put it
        above the image (y < 0, where Pillow clips it away), the label is
        placed just inside the top of the box instead.
        """
        above = box_top - text_height - cls._LABEL_PADDING
        return above if above >= 0 else box_top

    @staticmethod
    def _text_color(background):
        """Return black or white, whichever contrasts better with ``background`` (RGB)."""
        red, green, blue = background
        # ITU-R BT.601 luma weights.
        luma = 0.299 * red + 0.587 * green + 0.114 * blue
        return (0, 0, 0) if luma > 128 else (255, 255, 255)

    def _draw_labeled_box(self, draw, box, cls_name, score):
        """Draw one bounding box and its filled "name score" label.

        Args:
            draw: ``ImageDraw.Draw`` for an RGB Pillow image.
            box: ``(x1, y1, x2, y2)`` integer pixel coordinates.
            cls_name: Class name as reported by the model.
            score: Detection confidence.
        """
        x1, y1, x2, y2 = box
        color = self.color_mapping.get(cls_name, self.DEFAULT_COLOR)
        draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)

        # The fallback font cannot render Chinese, so use the English name there.
        if self._font_supports_cjk:
            display_name = self.class_name_mapping.get(cls_name, cls_name)
        else:
            display_name = cls_name
        text = f"{display_name} {score:.2f}"

        # textbbox at (0, 0) also reports how far the glyphs are offset from
        # the anchor. Subtract that offset when drawing so the text sits
        # inside its background.
        left, top, right, bottom = draw.textbbox((0, 0), text, font=self.font)
        text_width = right - left
        text_height = bottom - top
        label_top = self._label_top(y1, text_height)
        draw.rectangle(
            [(x1, label_top), (x1 + text_width, label_top + text_height + self._LABEL_PADDING)],
            fill=color,
        )
        draw.text(
            (x1 - left, label_top + self._LABEL_PADDING // 2 - top),
            text,
            fill=self._text_color(color),
            font=self.font,
        )

    def detect_and_draw_English(self, frame, conf=0.3, iou=0.5):
        """Detect objects in a frame and draw them with Ultralytics' plotter.

        Args:
            frame: Input image frame in BGR format.
            conf: Confidence threshold.
            iou: IoU threshold for NMS.

        Returns:
            The annotated image from ``Results.plot()``. If anything fails, the
            error is logged and the original ``frame`` is returned unchanged,
            so a video loop can keep running.
        """
        try:
            results = self.model(frame, conf=conf, iou=iou, half=True)
            return results[0].plot()
        except Exception:
            self._logger.exception("Detection error")
            return frame

    def detect_and_draw_Chinese(self, frame, conf=0.2, iou=0.3):
        """Detect objects in a frame and draw boxes with Chinese labels.

        Args:
            frame: Input image frame in BGR format.
            conf: Confidence threshold.
            iou: IoU threshold for NMS.

        Returns:
            A new BGR image with the detections drawn on it. If anything fails,
            the error is logged and the original ``frame`` is returned
            unchanged.
        """
        try:
            # NOTE(review): unlike detect_and_draw_English, half precision is
            # not requested here (it was commented out originally) — confirm
            # whether that is intentional.
            result = self.model(frame, conf=conf, iou=iou)[0]

            # cvtColor returns a new array, so the caller's frame is not modified.
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            draw = ImageDraw.Draw(pil_img)

            # Move all box data to host memory once per frame, not once per box.
            boxes = result.boxes
            coords = boxes.xyxy.cpu().numpy().astype(int).tolist()
            class_ids = boxes.cls.cpu().numpy().astype(int).tolist()
            scores = boxes.conf.cpu().numpy().tolist()

            for box, cls_id, score in zip(coords, class_ids, scores):
                self._draw_labeled_box(draw, tuple(box), result.names[cls_id], score)

            return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        except Exception:
            self._logger.exception("Detection error")
            return frame