Files
VisDrone-Version/rfdetr_core.py
2025-08-05 16:55:45 +08:00

254 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import supervision as sv
from rfdetr import RFDETRBase
from collections import defaultdict
from typing import Dict, Set
from PIL import Image, ImageDraw, ImageFont # 导入PIL库
import numpy as np # 导入numpy用于图像格式转换
import json # 新增
import os # 新增
class RFDETRDetector:
    """RF-DETR detector combined with ByteTrack tracking, counting and drawing.

    All settings are read from ``<base_config_dir>/<config_name>.json``.

    Required config keys: ``model_pth_filename``, ``resolution``,
    ``tracker_activation_threshold``, ``tracker_lost_buffer``,
    ``tracker_match_threshold``, ``tracker_consecutive_frames``,
    ``tracker_frame_rate``, ``classes_en`` (class names indexed by class id)
    and ``classes_zh_map`` (English class name -> label drawn on the frame).

    Optional config keys: ``font_path``, ``font_size``, ``detection_settings``
    (with ``enabled_classes`` and ``default_confidence_threshold``),
    ``default_confidence_threshold``, ``default_color_hex``,
    ``bounding_box_thickness``, ``class_colors_hex``, ``stats_panel_width``,
    ``stats_panel_margin_y``, ``stats_line_spacing`` and
    ``stats_text_color_hex``.
    """

    # Last-resort color when neither the requested color nor the configured
    # default color can be parsed (green, matching the "#00FF00" default).
    _FALLBACK_RGB = (0, 255, 0)

    def __init__(self, config_name: str, base_model_dir="models", base_config_dir="configs", default_font_path="./font/MSYH.TTC", default_font_size=15):
        """Load the JSON config and build the model, tracker, font and counters.

        Args:
            config_name: Config file name without the ``.json`` extension.
            base_model_dir: Directory containing the model weights file.
            base_config_dir: Directory containing the JSON config file.
            default_font_path: Font used when the config has no ``font_path``.
            default_font_size: Font size used when the config has no
                ``font_size``.

        Raises:
            FileNotFoundError: If the config file does not exist.
            KeyError: If a required config key is missing.
        """
        self.config_path = os.path.join(base_config_dir, f"{config_name}.json")
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        model_path = os.path.join(base_model_dir, self.config['model_pth_filename'])
        resolution = self.config['resolution']
        # Font path/size come from the config, falling back to the arguments.
        font_path = self.config.get('font_path', default_font_path)
        font_size = self.config.get('font_size', default_font_size)

        # 1. Model.
        self.model = RFDETRBase(
            pretrain_weights=model_path,
            resolution=resolution
        )

        # 2. Tracker.
        self.tracker = sv.ByteTrack(
            track_activation_threshold=self.config['tracker_activation_threshold'],
            lost_track_buffer=self.config['tracker_lost_buffer'],
            minimum_matching_threshold=self.config['tracker_match_threshold'],
            minimum_consecutive_frames=self.config['tracker_consecutive_frames'],
            frame_rate=self.config['tracker_frame_rate']
        )

        # 3. Class definitions.
        self.VISDRONE_CLASSES = self.config['classes_en']
        self.VISDRONE_CLASSES_CHINESE = self.config['classes_zh_map']

        # Per-class enable switches; a class missing from the filter is enabled.
        self.detection_settings = self.config.get('detection_settings', {})
        self.enabled_classes_filter = self.detection_settings.get('enabled_classes', {})
        self._active_classes_lookup = {
            cls_name: self.enabled_classes_filter.get(cls_name, True)
            for cls_name in self.VISDRONE_CLASSES
        }
        print(f"活动类别配置: {self._active_classes_lookup}")

        # 4. Font, with PIL's built-in font as a fallback.
        self.FONT_SIZE = font_size
        try:
            self.font = ImageFont.truetype(font_path, self.FONT_SIZE)
        except IOError:
            print(f"错误:无法加载字体 {font_path}。将使用默认字体。")
            self.font = ImageFont.load_default()

        # 5. Counters: track ids already seen per class, and the resulting totals.
        self.class_tracks: Dict[str, Set[int]] = defaultdict(set)
        self.category_counts: Dict[str, int] = defaultdict(int)

        # 6. Drawing settings.
        self.default_color_hex = self.config.get('default_color_hex', "#00FF00")
        self.bounding_box_thickness = self.config.get('bounding_box_thickness', 2)
        self.class_colors_hex = self.config.get('class_colors_hex', {})

        # Copy of the most recent frame produced by detect_and_draw_count.
        self.last_annotated_frame: np.ndarray | None = None

    @staticmethod
    def _parse_hex_digits(digits):
        """Return the (R, G, B) tuple for ``RRGGBB`` digits, or None if invalid."""
        if not isinstance(digits, str):
            return None
        try:
            return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
        except ValueError:
            return None

    def _hex_to_rgb(self, hex_color: str) -> tuple:
        """Convert a hex color string such as ``"#FF8000"`` to an RGB tuple.

        If ``hex_color`` cannot be parsed, ``self.default_color_hex`` is tried
        next, and finally green ``(0, 255, 0)`` is returned. The fallback is a
        bounded loop, so an unparsable default color cannot cause unbounded
        recursion. A warning is printed for every value that fails to parse.

        Args:
            hex_color: Color string, with or without leading ``#``.

        Returns:
            A 3-tuple of ints ``(R, G, B)``.
        """
        tried = []
        for candidate in (hex_color, self.default_color_hex):
            digits = candidate.lstrip('#') if isinstance(candidate, str) else candidate
            if digits in tried:
                # The default is the same value that just failed; skip it.
                continue
            tried.append(digits)
            rgb = self._parse_hex_digits(digits)
            if rgb is not None:
                return rgb
            print(f"警告: 无法解析十六进制颜色 '{digits}', 将使用默认颜色。")
        return self._FALLBACK_RGB

    def _update_counter(self, detections: "sv.Detections"):
        """Count each (class, tracker id) pair once.

        Detections without a tracker id, and class ids outside
        ``self.VISDRONE_CLASSES``, are ignored. The English class name is the
        key of both ``self.class_tracks`` and ``self.category_counts``.

        Args:
            detections: Object exposing ``class_id`` and ``tracker_id``
                sequences of equal length (either may be ``None``).
        """
        tracker_ids = detections.tracker_id
        class_ids = detections.class_id
        if tracker_ids is None or class_ids is None:
            return

        class_count = len(self.VISDRONE_CLASSES)
        for class_id, track_id in zip(class_ids, tracker_ids):
            if track_id is None:
                continue
            class_index = int(class_id)
            if not 0 <= class_index < class_count:
                # A negative index would silently select a class from the end.
                continue
            class_name = self.VISDRONE_CLASSES[class_index]
            seen_ids = self.class_tracks[class_name]
            track_key = int(track_id)
            if track_key not in seen_ids:
                seen_ids.add(track_key)
                self.category_counts[class_name] += 1

    def _draw_frame(self, frame: np.ndarray, detections: "sv.Detections") -> np.ndarray:
        """Draw boxes, class labels and the per-class count panel with PIL.

        Only detections that carry a tracker id are drawn. The input frame is
        treated as BGR (it is converted BGR -> RGB for PIL and back) and is not
        modified; a new BGR array is returned.

        Args:
            frame: Image array to annotate.
            detections: Detections to draw.

        Returns:
            The annotated frame as a new array.
        """
        # cvtColor returns a new array, so the caller's frame stays untouched.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)

        # --- Bounding boxes and labels ---
        tracker_ids = detections.tracker_id
        if tracker_ids is not None:
            text_color = (255, 255, 255)
            for box, class_id, track_id in zip(detections.xyxy, detections.class_id, tracker_ids):
                if track_id is None:
                    continue
                x1, y1, x2, y2 = map(int, box)
                english_label = self.VISDRONE_CLASSES[class_id]
                chinese_label = self.VISDRONE_CLASSES_CHINESE.get(english_label, english_label)

                box_color_hex = self.class_colors_hex.get(english_label, self.default_color_hex)
                box_rgb_color = self._hex_to_rgb(box_color_hex)
                draw.rectangle([x1, y1, x2, y2], outline=box_rgb_color, width=self.bounding_box_thickness)

                # Put the label above the box; move it inside the box when it
                # would otherwise fall off the top of the image.
                text_x = x1 + 2
                text_y = y1 - self.FONT_SIZE - 2
                if text_y < 0:
                    text_y = y1 + 2
                draw.text((text_x, text_y), f"{chinese_label}", font=self.font, fill=text_color)

        # --- Count panel (top-right corner), one line per class seen so far ---
        stats_text_lines = [
            f"{self.VISDRONE_CLASSES_CHINESE.get(cls, cls)}: {self.category_counts[cls]}"
            for cls in self.VISDRONE_CLASSES if self.category_counts[cls] > 0
        ]
        if stats_text_lines:
            frame_width = frame.shape[1]
            stats_start_x = frame_width - self.config.get('stats_panel_width', 200)
            stats_start_y = self.config.get('stats_panel_margin_y', 10)
            line_height = self.FONT_SIZE + self.config.get('stats_line_spacing', 5)
            stats_text_color = self._hex_to_rgb(self.config.get('stats_text_color_hex', "#FFFFFF"))
            for i, line in enumerate(stats_text_lines):
                draw.text((stats_start_x, stats_start_y + i * line_height), line, font=self.font, fill=stats_text_color)

        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    def _filter_enabled_classes(self, detections):
        """Keep only detections whose class is known and enabled in the config.

        Returns ``sv.Detections.empty()`` when ``detections`` is ``None``,
        empty, or nothing survives the filter. A warning is printed for every
        class id outside ``self.VISDRONE_CLASSES``.
        """
        if detections is None or len(detections) == 0:
            return sv.Detections.empty()

        class_count = len(self.VISDRONE_CLASSES)
        keep_indices = []
        for i, class_id in enumerate(detections.class_id):
            if not 0 <= class_id < class_count:
                print(f"警告: 检测到无效的 class_id {class_id},超出了已知类别范围。")
                continue
            class_name = self.VISDRONE_CLASSES[class_id]
            if self._active_classes_lookup.get(class_name, True):
                keep_indices.append(i)

        if not keep_indices:
            return sv.Detections.empty()
        return detections[keep_indices]

    def detect_and_draw_count(self, frame: np.ndarray, conf: float = -1.0) -> np.ndarray:
        """Detect, filter by class, track, count and annotate a single frame.

        Args:
            frame: Input image array.
            conf: Confidence threshold. ``-1.0`` (the default) or ``None``
                selects the configured threshold: first
                ``detection_settings.default_confidence_threshold``, then the
                top-level ``default_confidence_threshold``, then ``0.8``.

        Returns:
            The annotated frame. If any step raises, the error is printed and
            the unmodified input ``frame`` is returned instead. In both cases
            ``self.last_annotated_frame`` is set to a copy of the result
            (``None`` when ``frame`` is ``None``).
        """
        if conf is None or conf == -1.0:
            effective_conf = float(
                self.detection_settings.get(
                    'default_confidence_threshold',
                    self.config.get('default_confidence_threshold', 0.8),
                )
            )
        else:
            effective_conf = conf

        try:
            # 1. Detection.
            # NOTE(review): the frame is passed as-is while _draw_frame treats
            # it as BGR; confirm the channel order expected by model.predict.
            detections = self.model.predict(frame, threshold=effective_conf)

            # 2. Drop unknown or disabled classes before tracking.
            detections = self._filter_enabled_classes(detections)

            # 3. Track and count only when something is left; an empty result
            #    is drawn directly and never handed to the tracker.
            if len(detections) > 0:
                detections = self.tracker.update_with_detections(detections)
                self._update_counter(detections)

            # 4. Draw.
            annotated_frame = self._draw_frame(frame, detections)
            self.last_annotated_frame = annotated_frame.copy()
            return annotated_frame
        except Exception as e:
            # Per-frame boundary: a failure on one frame must not stop the
            # stream, so report it and hand back the original frame.
            print(f"处理帧时发生错误: {e}")
            self.last_annotated_frame = frame.copy() if frame is not None else None
            return frame