import json
import os
import traceback
from collections import defaultdict
from typing import Dict, Optional, Set, Tuple

import cv2
import numpy as np  # numpy: conversion between OpenCV arrays and PIL images
import supervision as sv
from PIL import Image, ImageDraw, ImageFont  # PIL: used so CJK label text can be rendered
from rfdetr import RFDETRBase


class RFDETRDetector:
    """RF-DETR detector with ByteTrack tracking, per-class unique-track counting,
    and PIL-based rendering of localized labels plus a statistics panel.

    All tunables are read from ``<base_config_dir>/<config_name>.json``.
    """

    # Last-resort colour (green) used when neither the requested colour nor the
    # configured default colour can be parsed.
    _FALLBACK_RGB: Tuple[int, int, int] = (0, 255, 0)
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")

    def __init__(self, config_name: str, base_model_dir="models", base_config_dir="configs",
                 default_font_path="./font/MSYH.TTC", default_font_size=15):
        """Load the JSON config and build the model, tracker, font and counters.

        Args:
            config_name: config file stem; ``<base_config_dir>/<config_name>.json`` is loaded.
            base_model_dir: directory joined with ``config['model_pth_filename']``.
            base_config_dir: directory containing the JSON config files.
            default_font_path: font used when the config has no ``font_path``.
            default_font_size: font size used when the config has no ``font_size``.

        Raises:
            FileNotFoundError: if the config file does not exist.
            KeyError: if a required config key (model, tracker or class settings) is missing.
        """
        self.config_path = os.path.join(base_config_dir, f"{config_name}.json")
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        model_path = os.path.join(base_model_dir, self.config['model_pth_filename'])
        resolution = self.config['resolution']
        # Font path/size come from the config, falling back to the constructor defaults.
        font_path = self.config.get('font_path', default_font_path)
        font_size = self.config.get('font_size', default_font_size)

        # 1. Model
        self.model = RFDETRBase(
            pretrain_weights=model_path,
            resolution=resolution,
        )

        # 2. Tracker
        self.tracker = sv.ByteTrack(
            track_activation_threshold=self.config['tracker_activation_threshold'],
            lost_track_buffer=self.config['tracker_lost_buffer'],
            minimum_matching_threshold=self.config['tracker_match_threshold'],
            minimum_consecutive_frames=self.config['tracker_consecutive_frames'],
            frame_rate=self.config['tracker_frame_rate'],
        )

        # 3. Class definitions: model class ids index into classes_en;
        #    classes_zh_map maps English names to display names.
        self.VISDRONE_CLASSES = self.config['classes_en']
        self.VISDRONE_CLASSES_CHINESE = self.config['classes_zh_map']

        # Per-class enable switches; classes missing from the filter default to enabled.
        self.detection_settings = self.config.get('detection_settings', {})
        self.enabled_classes_filter = self.detection_settings.get('enabled_classes', {})
        self._active_classes_lookup = {
            cls_name: self.enabled_classes_filter.get(cls_name, True)
            for cls_name in self.VISDRONE_CLASSES
        }
        print(f"活动类别配置: {self._active_classes_lookup}")

        # 4. Font
        self.FONT_SIZE = font_size
        try:
            self.font = ImageFont.truetype(font_path, self.FONT_SIZE)
        except OSError:
            print(f"错误:无法加载字体 {font_path}。将使用默认字体。")
            self.font = ImageFont.load_default()

        # 5. Counters: unique tracker ids seen per class, and their counts.
        self.class_tracks: Dict[str, Set[int]] = defaultdict(set)
        self.category_counts: Dict[str, int] = defaultdict(int)

        # 6. Drawing settings
        self.default_color_hex = self.config.get('default_color_hex', "#00FF00")
        self.bounding_box_thickness = self.config.get('bounding_box_thickness', 2)
        self.class_colors_hex = self.config.get('class_colors_hex', {})
        # Colour values already reported as unparsable (so each is warned about once).
        self._bad_colors_warned: Set[str] = set()

        # Most recent annotated frame (or raw frame after an error).
        self.last_annotated_frame: Optional[np.ndarray] = None

    @classmethod
    def _parse_hex_rgb(cls, hex_color) -> Optional[Tuple[int, int, int]]:
        """Parse '#RRGGBB', 'RRGGBB', '#RGB' shorthand or '#RRGGBBAA' (alpha ignored).

        Returns None when *hex_color* is not a parsable hex colour string.
        """
        if not isinstance(hex_color, str):
            return None
        digits = hex_color.strip().lstrip('#')
        if not digits or not set(digits) <= cls._HEX_DIGITS:
            return None
        if len(digits) == 3:
            digits = ''.join(ch * 2 for ch in digits)
        if len(digits) not in (6, 8):
            return None
        return (int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16))

    def _hex_to_rgb(self, hex_color: str) -> tuple:
        """Convert a hex colour string to an (r, g, b) tuple.

        Unparsable values fall back to ``default_color_hex``, then to green.
        A warning is printed once per distinct bad value. Unlike a recursive
        fallback, this cannot loop forever when the default itself is invalid.
        """
        rgb = self._parse_hex_rgb(hex_color)
        if rgb is not None:
            return rgb
        key = str(hex_color)
        if key not in self._bad_colors_warned:
            self._bad_colors_warned.add(key)
            print(f"警告: 无法解析十六进制颜色 '{hex_color}', 将使用默认颜色。")
        default_rgb = self._parse_hex_rgb(self.default_color_hex)
        return default_rgb if default_rgb is not None else self._FALLBACK_RGB

    def _update_counter(self, detections: sv.Detections):
        """Count each (class, tracker id) pair once; detections without a tracker id are ignored."""
        tracker_ids = detections.tracker_id
        class_ids = detections.class_id
        if tracker_ids is None or class_ids is None or len(detections) == 0:
            return
        num_classes = len(self.VISDRONE_CLASSES)
        for class_id, track_id in zip(class_ids, tracker_ids):
            if track_id is None:
                continue
            class_idx = int(class_id)
            if not 0 <= class_idx < num_classes:
                continue  # unknown class id: nothing to attribute the track to
            # English class names are the internal keys.
            class_name = self.VISDRONE_CLASSES[class_idx]
            track_key = int(track_id)
            seen = self.class_tracks[class_name]
            if track_key not in seen:
                seen.add(track_key)
                self.category_counts[class_name] += 1

    def _draw_detections(self, draw: ImageDraw.ImageDraw, detections: sv.Detections) -> None:
        """Draw a box and localized label for every detection that has a tracker id."""
        tracker_ids = detections.tracker_id
        class_ids = detections.class_id
        if tracker_ids is None or class_ids is None or len(detections) == 0:
            return
        num_classes = len(self.VISDRONE_CLASSES)
        text_color = (255, 255, 255)  # white label text (RGB)
        for box, class_id, track_id in zip(detections.xyxy, class_ids, tracker_ids):
            if track_id is None:
                continue
            class_idx = int(class_id)
            if not 0 <= class_idx < num_classes:
                continue
            x1, y1, x2, y2 = map(int, box)
            # Pillow's rectangle() raises if the corners are inverted; normalize them.
            x1, x2 = sorted((x1, x2))
            y1, y2 = sorted((y1, y2))

            english_label = self.VISDRONE_CLASSES[class_idx]
            chinese_label = self.VISDRONE_CLASSES_CHINESE.get(english_label, english_label)
            box_rgb = self._hex_to_rgb(
                self.class_colors_hex.get(english_label, self.default_color_hex)
            )
            draw.rectangle([x1, y1, x2, y2], outline=box_rgb, width=self.bounding_box_thickness)

            # Put the label above the box, or inside it when it would leave the top edge.
            text_y = y1 - self.FONT_SIZE - 2
            if text_y < 0:
                text_y = y1 + 2
            draw.text((x1 + 2, text_y), f"{chinese_label}", font=self.font, fill=text_color)

    def _draw_stats_panel(self, draw: ImageDraw.ImageDraw, frame_width: int) -> None:
        """Draw the non-zero per-class unique-track counts in the top-right corner."""
        lines = [
            f"{self.VISDRONE_CLASSES_CHINESE.get(cls, cls)}: {self.category_counts.get(cls, 0)}"
            for cls in self.VISDRONE_CLASSES
            if self.category_counts.get(cls, 0) > 0
        ]
        if not lines:
            return
        # Clamp so the panel stays on-screen for frames narrower than the panel width.
        start_x = max(0, frame_width - self.config.get('stats_panel_width', 200))
        start_y = self.config.get('stats_panel_margin_y', 10)
        line_height = self.FONT_SIZE + self.config.get('stats_line_spacing', 5)
        text_color = self._hex_to_rgb(self.config.get('stats_text_color_hex', "#FFFFFF"))
        for i, line in enumerate(lines):
            draw.text((start_x, start_y + i * line_height), line, font=self.font, fill=text_color)

    def _draw_frame(self, frame: np.ndarray, detections: sv.Detections) -> np.ndarray:
        """Render tracked boxes, labels and the counting panel using PIL.

        Args:
            frame: BGR image (OpenCV convention); it is not modified.
            detections: detections to draw; only entries with a tracker id are drawn.

        Returns:
            A new annotated BGR image.
        """
        # cv2.cvtColor allocates a new array, so *frame* itself is never touched.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        self._draw_detections(draw, detections)
        self._draw_stats_panel(draw, frame.shape[1])
        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    def _resolve_confidence(self, conf) -> float:
        """Return *conf*, or the configured default when *conf* is -1.0 (or None).

        The default is looked up in detection_settings, then the top-level config, then 0.8.
        """
        if conf is not None and conf != -1.0:
            return conf
        return float(
            self.detection_settings.get(
                'default_confidence_threshold',
                self.config.get('default_confidence_threshold', 0.8),
            )
        )

    def _filter_active_classes(self, detections: sv.Detections) -> sv.Detections:
        """Drop detections whose class is disabled in the config or whose class id is invalid."""
        if len(detections) == 0 or detections.class_id is None:
            return detections
        num_classes = len(self.VISDRONE_CLASSES)
        keep = np.zeros(len(detections), dtype=bool)
        for i, class_id in enumerate(detections.class_id):
            class_idx = int(class_id)
            if 0 <= class_idx < num_classes:
                keep[i] = bool(
                    self._active_classes_lookup.get(self.VISDRONE_CLASSES[class_idx], True)
                )
            else:
                print(f"警告: 检测到无效的 class_id {class_id},超出了已知类别范围。")
        if not keep.any():
            return sv.Detections.empty()
        if keep.all():
            return detections
        return detections[keep]

    def detect_and_draw_count(self, frame: np.ndarray, conf: float = -1.0) -> np.ndarray:
        """Detect, filter by enabled classes, track, count and annotate a single frame.

        Args:
            frame: BGR image.
            conf: confidence threshold; -1.0 (or None) uses the configured default.

        Returns:
            The annotated BGR frame. On any processing error, the error is printed
            and the input *frame* is returned unchanged (best-effort per frame).
        """
        effective_conf = self._resolve_confidence(conf)
        try:
            detections = self.model.predict(frame, threshold=effective_conf)
            if detections is None or len(detections) == 0:
                detections = sv.Detections.empty()

            detections = self._filter_active_classes(detections)

            # Update the tracker on every frame, including empty ones, so ByteTrack
            # ages out lost tracks instead of freezing them across empty frames.
            detections = self.tracker.update_with_detections(detections)

            self._update_counter(detections)

            annotated_frame = self._draw_frame(frame, detections)
            self.last_annotated_frame = annotated_frame.copy()
            return annotated_frame
        except Exception as e:
            print(f"处理帧时发生错误: {e}")
            traceback.print_exc()
            self.last_annotated_frame = frame.copy() if frame is not None else None
            return frame