参数分离完全版

This commit is contained in:
2025-08-05 16:55:45 +08:00
commit 25a2ded11d
16 changed files with 2013 additions and 0 deletions

254
rfdetr_core.py Normal file
View File

@ -0,0 +1,254 @@
import cv2
import supervision as sv
from rfdetr import RFDETRBase
from collections import defaultdict
from typing import Dict, Set
from PIL import Image, ImageDraw, ImageFont # 导入PIL库
import numpy as np # 导入numpy用于图像格式转换
import json # 新增
import os # 新增
class RFDETRDetector:
def __init__(self, config_name: str, base_model_dir="models", base_config_dir="configs", default_font_path="./font/MSYH.TTC", default_font_size=15):
self.config_path = os.path.join(base_config_dir, f"{config_name}.json")
if not os.path.exists(self.config_path):
raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
with open(self.config_path, 'r', encoding='utf-8') as f:
self.config = json.load(f)
model_path = os.path.join(base_model_dir, self.config['model_pth_filename'])
resolution = self.config['resolution']
# 从配置读取字体路径和大小,如果未提供则使用默认值
font_path = self.config.get('font_path', default_font_path)
font_size = self.config.get('font_size', default_font_size)
# 1. 初始化模型
self.model = RFDETRBase(
pretrain_weights=model_path,
# pretrain_weights=model_path or r"E:\A\rf-detr-main\output\pre-train1\checkpoint_best_ema.pth",
resolution=resolution
)
# 2. 初始化跟踪器
self.tracker = sv.ByteTrack(
track_activation_threshold=self.config['tracker_activation_threshold'],
lost_track_buffer=self.config['tracker_lost_buffer'],
minimum_matching_threshold=self.config['tracker_match_threshold'],
minimum_consecutive_frames=self.config['tracker_consecutive_frames'],
frame_rate=self.config['tracker_frame_rate']
)
# 3. 类别定义
self.VISDRONE_CLASSES = self.config['classes_en']
self.VISDRONE_CLASSES_CHINESE = self.config['classes_zh_map']
# 新增:加载类别启用配置
self.detection_settings = self.config.get('detection_settings', {})
self.enabled_classes_filter = self.detection_settings.get('enabled_classes', {})
# 构建一个查找表对于未在filter中指定的类别默认为 True (启用)
self._active_classes_lookup = {
cls_name: self.enabled_classes_filter.get(cls_name, True)
for cls_name in self.VISDRONE_CLASSES
}
print(f"活动类别配置: {self._active_classes_lookup}")
# 4. 初始化字体
self.FONT_SIZE = font_size
try:
self.font = ImageFont.truetype(font_path, self.FONT_SIZE)
except IOError:
print(f"错误:无法加载字体 {font_path}。将使用默认字体。")
self.font = ImageFont.load_default() # 使用真正通用的默认字体
# 5. 类别计数器 (作为类属性)
self.class_tracks: Dict[str, Set[int]] = defaultdict(set)
self.category_counts: Dict[str, int] = defaultdict(int)
# 6. 初始化标注器
# 从配置加载默认颜色,如果失败则使用预设颜色
self.default_color_hex = self.config.get('default_color_hex', "#00FF00") # 默认绿色
self.bounding_box_thickness = self.config.get('bounding_box_thickness', 2)
# 加载颜色配置,用于 PIL 绘制
self.class_colors_hex = self.config.get('class_colors_hex', {})
self.last_annotated_frame: np.ndarray | None = None # 新增: 用于存储最新的标注帧
def _hex_to_rgb(self, hex_color: str) -> tuple:
"""将十六进制颜色字符串转换为RGB元组。"""
hex_color = hex_color.lstrip('#')
try:
return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
except ValueError:
print(f"警告: 无法解析十六进制颜色 '{hex_color}', 将使用默认颜色。")
# 解析失败时返回一个默认颜色,例如红色
return self._hex_to_rgb(self.default_color_hex if self.default_color_hex != hex_color else "#00FF00")
def _update_counter(self, detections: sv.Detections):
"""更新类别计数器"""
# 只统计有 tracker_id 的检测结果
valid_indices = detections.tracker_id != None
if not np.any(valid_indices): # 处理 detections 为空或 tracker_id 都为 None 的情况
return
class_ids = detections.class_id[valid_indices]
track_ids = detections.tracker_id[valid_indices]
for class_id, track_id in zip(class_ids, track_ids):
if track_id is None: # 跳过没有 tracker_id 的项
continue
# 使用英文类别名作为内部 key
class_name = self.VISDRONE_CLASSES[class_id]
if track_id not in self.class_tracks[class_name]:
self.class_tracks[class_name].add(track_id)
self.category_counts[class_name] += 1
def _draw_frame(self, frame: np.ndarray, detections: sv.Detections) -> np.ndarray:
"""使用PIL绘制检测框、中文标签和计数信息"""
pil_image = Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
# --- 使用 PIL 绘制检测框和中文标签 ---
valid_indices = detections.tracker_id != None # 或直接使用 detections.xyxy 如果不过滤无 tracker_id 的
if np.any(valid_indices):
boxes = detections.xyxy[valid_indices]
class_ids = detections.class_id[valid_indices]
# tracker_ids = detections.tracker_id[valid_indices] # 如果需要 tracker_id
for box, class_id in zip(boxes, class_ids):
x1, y1, x2, y2 = map(int, box)
english_label = self.VISDRONE_CLASSES[class_id]
chinese_label = self.VISDRONE_CLASSES_CHINESE.get(english_label, english_label)
# 获取边界框颜色
box_color_hex = self.class_colors_hex.get(english_label, self.default_color_hex)
box_rgb_color = self._hex_to_rgb(box_color_hex)
# 绘制边界框
draw.rectangle([x1, y1, x2, y2], outline=box_rgb_color, width=self.bounding_box_thickness)
# 绘制中文标签 (与之前逻辑类似)
text_to_draw = f"{chinese_label}"
# 标签背景 (可选,使其更易读)
# label_text_bbox = draw.textbbox((0,0), text_to_draw, font=self.font)
# label_width = label_text_bbox[2] - label_text_bbox[0]
# label_height = label_text_bbox[3] - label_text_bbox[1]
# label_bg_y1 = y1 - label_height - 4 if y1 - label_height - 4 > 0 else y1 + 2
# draw.rectangle([x1, label_bg_y1, x1 + label_width + 4, label_bg_y1 + label_height + 2], fill=box_rgb_color)
# text_color = (255,255,255) if sum(box_rgb_color) < 382 else (0,0,0) # 简易对比色
text_color = (255, 255, 255) # 白色 (RGB)
text_x = x1 + 2 # 稍微偏移,避免紧贴边框
text_y = y1 - self.FONT_SIZE - 2
if text_y < 0: # 如果标签超出图像顶部
text_y = y1 + 2
draw.text((text_x, text_y), text_to_draw, font=self.font, fill=text_color)
# --- 绘制统计面板 (右上角) ---
stats_text_lines = [
f"{self.VISDRONE_CLASSES_CHINESE.get(cls, cls)}: {self.category_counts[cls]}"
for cls in self.VISDRONE_CLASSES if self.category_counts[cls] > 0
]
frame_height, frame_width, _ = frame.shape
stats_start_x = frame_width - self.config.get('stats_panel_width', 200)
stats_start_y = self.config.get('stats_panel_margin_y', 10)
line_height = self.FONT_SIZE + self.config.get('stats_line_spacing', 5)
stats_text_color_hex = self.config.get('stats_text_color_hex', "#FFFFFF")
stats_text_color = self._hex_to_rgb(stats_text_color_hex)
# 可选:为统计面板添加背景
if stats_text_lines:
panel_height = len(stats_text_lines) * line_height + 10
panel_y2 = stats_start_y + panel_height
# 半透明背景
# overlay = Image.new('RGBA', pil_image.size, (0,0,0,0))
# panel_draw = ImageDraw.Draw(overlay)
# panel_draw.rectangle(
# [stats_start_x - 5, stats_start_y - 5, frame_width - 5, panel_y2],
# fill=(100, 100, 100, 128) # 半透明灰色
# )
# pil_image = Image.alpha_composite(pil_image.convert('RGBA'), overlay)
# draw = ImageDraw.Draw(pil_image) # 如果用了 alpha_composite, 需要重新获取 draw 对象
# 或者简单不透明背景
# draw.rectangle(
# [stats_start_x - 5, stats_start_y - 5, frame_width - 5, panel_y2],
# fill=self._hex_to_rgb(self.config.get('stats_panel_bg_color_hex', "#808080")) # 例如灰色背景
# )
for i, line in enumerate(stats_text_lines):
text_pos = (stats_start_x, stats_start_y + i * line_height)
draw.text(text_pos, line, font=self.font, fill=stats_text_color)
final_annotated_frame = cv2.cvtColor(np.array(pil_image.convert('RGB')), cv2.COLOR_RGB2BGR)
return final_annotated_frame
def detect_and_draw_count(self, frame: np.ndarray, conf: float = -1.0) -> np.ndarray:
"""执行单帧检测、跟踪、计数并绘制结果(包含类别过滤)。"""
if conf == -1.0:
# 优先从 detection_settings 中获取其次是顶层config最后是硬编码默认值
effective_conf = float(
self.detection_settings.get('default_confidence_threshold',
self.config.get('default_confidence_threshold', 0.8))
)
else:
effective_conf = conf
try:
# 1. 执行检测
detections = self.model.predict(frame, threshold=effective_conf)
# 处理 detections 为 None 或空的情况
if detections is None or len(detections) == 0:
detections = sv.Detections.empty()
annotated_frame = self._draw_frame(frame, detections)
self.last_annotated_frame = annotated_frame.copy() # 新增
return annotated_frame
# 新增:根据配置过滤检测到的类别
if detections is not None and len(detections) > 0:
keep_indices = []
for i, class_id in enumerate(detections.class_id):
if class_id < len(self.VISDRONE_CLASSES): # 确保 class_id 有效
class_name = self.VISDRONE_CLASSES[class_id]
if self._active_classes_lookup.get(class_name, True): # 默认为 True
keep_indices.append(i)
else:
print(f"警告: 检测到无效的 class_id {class_id},超出了已知类别范围。")
if not keep_indices:
detections = sv.Detections.empty()
else:
detections = detections[keep_indices]
# 如果过滤后没有检测结果
if len(detections) == 0:
annotated_frame = self._draw_frame(frame, sv.Detections.empty())
self.last_annotated_frame = annotated_frame.copy() # 新增
return annotated_frame
# 2. 执行跟踪 (只对过滤后的结果进行跟踪)
detections = self.tracker.update_with_detections(detections)
# 3. 更新计数器 (只对过滤并跟踪后的结果进行计数)
self._update_counter(detections)
# 4. 绘制结果
annotated_frame = self._draw_frame(frame, detections)
self.last_annotated_frame = annotated_frame.copy() # 新增
return annotated_frame
except Exception as e:
print(f"处理帧时发生错误: {e}")
if frame is not None:
self.last_annotated_frame = frame.copy() # 新增
else:
self.last_annotated_frame = None # 新增
return frame