Files
AlgorithmCollection/manager.py

116 lines
4.0 KiB
Python
Raw Permalink Normal View History

2025-12-02 16:43:56 +08:00
import os
from collections import defaultdict
from config import MODEL_CONFIGS
from detect import YOLODetector
class UnifiedDetectionManager:
    """Unified detection manager.

    Loads every model listed in ``MODEL_CONFIGS`` as a ``YOLODetector`` and
    routes requested detection types (class names) to the model that handles
    them.

    Attributes:
        detectors: model name -> detector instance.
        type_to_model: lower-cased type name -> name of the model that
            detects it (when several models declare a type, the one loaded
            last wins and a warning is printed).
        loaded_models: names of successfully loaded models, in load order.
        type_to_id: lower-cased type name -> class ID, taken from the owning
            model's ``type_to_id`` config.
    """

    def __init__(self):
        self.detectors = {}      # model name -> detector instance
        self.type_to_model = {}  # lower-cased type name -> model name
        self.loaded_models = []  # names of models that loaded successfully
        self.type_to_id = {}     # lower-cased type name -> class ID
        self._load_models()

    def _load_models(self):
        """Load every model in ``MODEL_CONFIGS`` and build the type mappings.

        A model is skipped, with a printed message, when:
        - its model file is missing,
        - its config is incomplete (e.g. a type absent from ``type_to_id``), or
        - its detector cannot be constructed.

        Manager state is only updated once a model has fully succeeded, so a
        failing model never leaves a half-registered entry behind.

        Raises:
            ValueError: if ``MODEL_CONFIGS`` is empty.
        """
        if not MODEL_CONFIGS:
            raise ValueError("模型配置为空")
        for model_name, config in MODEL_CONFIGS.items():
            try:
                model_path = config["model_path"]
                if not os.path.exists(model_path):
                    print(f"跳过 {model_name}: 模型文件不存在 - {model_path}")
                    continue
                # Resolve every (type, lower-cased type, class ID) triple up
                # front: a bad config fails here, before the (presumably
                # expensive) detector is built and before any shared state
                # is touched.
                mappings = [
                    (det_type, det_type.lower(),
                     config["type_to_id"][det_type.lower()])
                    for det_type in config["types"]
                ]
                # NOTE(review): an earlier comment said the new enable_primary
                # setting is passed through automatically; presumably it lives
                # inside config["params"] — confirm against detect.YOLODetector.
                detector = YOLODetector(
                    model_path=model_path,
                    params=config["params"],
                    type_to_id=config["type_to_id"],
                )
            except Exception as e:
                # Best effort: one broken model must not stop the others.
                print(f"加载失败 {model_name}: {str(e)}")
                continue

            # Commit: everything for this model has succeeded.
            self.detectors[model_name] = detector
            self.loaded_models.append(model_name)
            for det_type, det_type_lower, type_id in mappings:
                if det_type_lower in self.type_to_model:
                    print(f"警告: 类别 '{det_type}' 映射冲突")
                # On a conflict, the model loaded later takes over the type.
                self.type_to_model[det_type_lower] = model_name
                self.type_to_id[det_type_lower] = type_id
            print(f"加载成功: {model_name}")

        print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
        print(f"支持类别: {list(self.type_to_model.keys())}")

    def parse_types(self, types_str):
        """Group a comma-separated list of detection types by owning model.

        Args:
            types_str: comma-separated type names, e.g. ``"person, car"``.
                Entries are stripped and lower-cased; empty entries and
                duplicates are dropped, keeping first-occurrence order.

        Returns:
            ``defaultdict(list)`` mapping model name -> list of requested
            types that model handles, in the caller's order. Unknown types
            are reported and ignored.

        Raises:
            ValueError: if ``types_str`` is empty, or none of its types is
                handled by a loaded model.
        """
        if not types_str:
            raise ValueError("检测类型为空")
        # dict.fromkeys de-duplicates while preserving order. A set's
        # iteration order of strings varies between processes (hash
        # randomization), which made the result order non-deterministic.
        requested_types = list(dict.fromkeys(
            t.strip().lower() for t in types_str.split(",") if t.strip()
        ))
        model_type_map = defaultdict(list)
        for det_type in requested_types:
            model_name = self.type_to_model.get(det_type)
            if model_name is None:
                print(f"忽略未知类别: {det_type}")
                continue
            model_type_map[model_name].append(det_type)
        if not model_type_map:
            raise ValueError("无有效检测类别")
        return model_type_map

    def detect(self, img_path, detection_types):
        """Run every model needed for ``detection_types`` on one image.

        Args:
            img_path: path to the image file.
            detection_types: comma-separated type names (see ``parse_types``).

        Returns:
            A flat list concatenating each model's ``detect`` results, in
            model order. A model whose detection raises is reported and
            skipped.

        Raises:
            FileNotFoundError: if ``img_path`` does not exist.
            ValueError: propagated from ``parse_types``.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"图像不存在: {img_path}")
        model_type_map = self.parse_types(detection_types)
        all_results = []
        for model_name, target_types in model_type_map.items():
            detector = self.detectors.get(model_name)
            if detector is None:
                continue
            print(f"检测: {model_name} -> {target_types}")
            try:
                results = detector.detect(img_path, target_types)
                all_results.extend(results)
                # NOTE(review): presumably statistics for the run just made;
                # semantics are defined by YOLODetector.get_detection_stats.
                stats = detector.get_detection_stats()
                print(f" {model_name}详细统计: {stats}")
            except Exception as e:
                # Best effort: a failing model does not abort the others.
                print(f"检测失败 {model_name}: {str(e)}")
        print(f"检测完成: 总共 {len(all_results)} 个结果")
        return all_results

    def get_available_info(self):
        """Return a snapshot of loaded models and supported types.

        Copies are returned so callers cannot mutate the manager's state.
        """
        return {
            "loaded_models": list(self.loaded_models),
            "supported_types": list(self.type_to_model.keys()),
            "type_to_model": dict(self.type_to_model),
        }