import os
from collections import defaultdict

from config import MODEL_CONFIGS
from detect import YOLODetector


class UnifiedDetectionManager:
    """Unified detection manager.

    Loads every model listed in ``MODEL_CONFIGS`` and routes requested
    detection categories to the model that serves them.

    Attributes:
        detectors: model name -> ``YOLODetector`` instance.
        type_to_model: lower-cased category name -> owning model name.
        loaded_models: names of successfully loaded models, in load order.
        type_to_id: lower-cased category name -> id read from the owning
            model's ``type_to_id`` config entry.
    """

    def __init__(self):
        self.detectors = {}  # detector instances, keyed by model name
        self.type_to_model = {}  # category -> model name
        self.loaded_models = []  # models that loaded successfully
        self.type_to_id = {}  # global category -> id mapping
        self._load_models()

    def _load_models(self):
        """Load every configured model, skipping missing or failing ones.

        A model is registered (detector, loaded list, category mappings)
        only after all of its setup succeeded, so a broken config entry
        never leaves a half-registered model behind.

        Raises:
            ValueError: if ``MODEL_CONFIGS`` is empty.
        """
        if not MODEL_CONFIGS:
            raise ValueError("模型配置为空")

        for model_name, config in MODEL_CONFIGS.items():
            try:
                model_path = config["model_path"]
                if not os.path.exists(model_path):
                    print(f"跳过 {model_name}: 模型文件不存在 - {model_path}")
                    continue

                # Resolve every category's id up front: a type missing from
                # type_to_id now fails this model before any state is touched
                # (and before the comparatively expensive detector build).
                type_entries = []
                for det_type in config["types"]:
                    det_type_lower = det_type.lower()
                    type_entries.append(
                        (det_type, det_type_lower, config["type_to_id"][det_type_lower])
                    )

                # Create the detector (the new enable_primary setting is passed
                # through automatically via params)
                detector = YOLODetector(
                    model_path=model_path,
                    params=config["params"],
                    type_to_id=config["type_to_id"],
                )
            except Exception as e:
                print(f"加载失败 {model_name}: {str(e)}")
                continue

            # Commit state only now that everything for this model succeeded.
            self.detectors[model_name] = detector
            self.loaded_models.append(model_name)

            for det_type, det_type_lower, type_id in type_entries:
                if det_type_lower in self.type_to_model:
                    # Later models win; the conflict is only reported.
                    print(f"警告: 类别 '{det_type}' 映射冲突")
                self.type_to_model[det_type_lower] = model_name
                self.type_to_id[det_type_lower] = type_id

            print(f"加载成功: {model_name}")

        print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
        print(f"支持类别: {list(self.type_to_model.keys())}")

    def parse_types(self, types_str):
        """Parse a comma-separated category string and group it by model.

        Categories are stripped, lower-cased and de-duplicated while keeping
        their first-seen order, so detectors receive a deterministic list.
        Unknown categories are reported and ignored.

        Args:
            types_str: comma-separated category names, e.g. ``"car, person"``.

        Returns:
            defaultdict(list) mapping model name -> list of its requested
            categories.

        Raises:
            ValueError: if ``types_str`` is empty, or no requested category
                is served by a loaded model.
        """
        if not types_str:
            raise ValueError("检测类型为空")

        # Order-preserving de-duplication (a set would scramble the order).
        requested_types = list(dict.fromkeys(
            t.strip().lower() for t in types_str.split(',') if t.strip()
        ))

        model_type_map = defaultdict(list)
        for det_type in requested_types:
            model_name = self.type_to_model.get(det_type)
            if model_name is None:
                print(f"忽略未知类别: {det_type}")
                continue
            model_type_map[model_name].append(det_type)

        if not model_type_map:
            raise ValueError("无有效检测类别")

        return model_type_map

    def detect(self, img_path, detection_types):
        """Run detection on one image across all models serving the request.

        Args:
            img_path: path to the image file.
            detection_types: comma-separated category string; see
                ``parse_types``.

        Returns:
            A flat list with every model's results, concatenated in model
            order. A model that raises is reported and skipped.

        Raises:
            FileNotFoundError: if ``img_path`` does not exist.
            ValueError: propagated from ``parse_types``.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"图像不存在: {img_path}")

        model_type_map = self.parse_types(detection_types)

        # Run detection (adapts automatically to the enable_primary setting)
        all_results = []
        for model_name, target_types in model_type_map.items():
            # Defensive: type_to_model only references registered detectors.
            if model_name not in self.detectors:
                continue

            print(f"检测: {model_name} -> {target_types}")
            detector = self.detectors[model_name]
            try:
                results = detector.detect(img_path, target_types)
                all_results.extend(results)

                # Fetch detailed per-model statistics
                stats = detector.get_detection_stats()
                print(f"  {model_name}详细统计: {stats}")
            except Exception as e:
                print(f"检测失败 {model_name}: {str(e)}")

        print(f"检测完成: 总共 {len(all_results)} 个结果")
        return all_results

    def get_available_info(self):
        """Return a snapshot of the loaded models and supported categories.

        Copies are returned so callers cannot mutate the manager's
        internal state through the result.
        """
        return {
            "loaded_models": list(self.loaded_models),
            "supported_types": list(self.type_to_model.keys()),
            "type_to_model": dict(self.type_to_model),
        }