116 lines
4.0 KiB
Python
116 lines
4.0 KiB
Python
import os
|
||
from collections import defaultdict
|
||
|
||
from config import MODEL_CONFIGS
|
||
from detect import YOLODetector
|
||
|
||
|
||
class UnifiedDetectionManager:
    """Route detection requests to the model that owns each category.

    At construction time every entry of ``MODEL_CONFIGS`` whose model file
    exists is wrapped in a ``YOLODetector``. Category names are stored
    lower-cased, so lookups are case-insensitive.
    """

    def __init__(self):
        """Create empty registries, then load every configured model.

        Raises:
            ValueError: If ``MODEL_CONFIGS`` is empty.
        """
        self.detectors = {}  # model name -> detector instance
        self.type_to_model = {}  # lower-cased category -> owning model name
        self.loaded_models = []  # names of loaded models, in load order
        self.type_to_id = {}  # lower-cased category -> global category id

        self._load_models()

    def _load_models(self):
        """Instantiate a detector for every usable entry in ``MODEL_CONFIGS``.

        Models whose file is missing are skipped with a message. A model
        whose configuration is incomplete, or whose detector fails to
        construct, is reported and skipped as a whole: nothing is registered
        for it, so ``loaded_models`` and the category maps never describe a
        half-configured model.

        Raises:
            ValueError: If ``MODEL_CONFIGS`` is empty.
        """
        if not MODEL_CONFIGS:
            raise ValueError("模型配置为空")

        for model_name, config in MODEL_CONFIGS.items():
            try:
                model_path = config["model_path"]
                if not os.path.exists(model_path):
                    print(f"跳过 {model_name}: 模型文件不存在 - {model_path}")
                    continue

                # Resolve every category id BEFORE touching shared state, so
                # a missing key cannot leave this model partially registered.
                model_type_ids = config["type_to_id"]
                resolved_types = []
                for det_type in config["types"]:
                    det_type_lower = det_type.lower()
                    resolved_types.append(
                        (det_type, det_type_lower, model_type_ids[det_type_lower])
                    )

                # ``params`` is forwarded untouched, so any option it carries
                # (the original note mentions ``enable_primary``) reaches the
                # detector without special handling here.
                detector = YOLODetector(
                    model_path=model_path,
                    params=config["params"],
                    type_to_id=model_type_ids,
                )
            except Exception as e:
                # Best effort: one broken model must not prevent the others
                # from loading.
                print(f"加载失败 {model_name}: {str(e)}")
                continue

            # Commit: from here on nothing can raise.
            self.detectors[model_name] = detector
            self.loaded_models.append(model_name)

            for det_type, det_type_lower, type_id in resolved_types:
                if det_type_lower in self.type_to_model:
                    # The later model wins; the warning makes that visible.
                    print(f"警告: 类别 '{det_type}' 映射冲突")
                self.type_to_model[det_type_lower] = model_name
                self.type_to_id[det_type_lower] = type_id

            print(f"加载成功: {model_name}")

        print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
        print(f"支持类别: {list(self.type_to_model.keys())}")

    def parse_types(self, types_str):
        """Group a comma-separated category string by owning model.

        Names are stripped, lower-cased and de-duplicated; blank items are
        dropped and unknown categories are reported and ignored. Both the
        models and the categories inside each model keep the order in which
        they first appear in ``types_str``.

        Args:
            types_str: Comma-separated category names, e.g. ``"person, car"``.

        Returns:
            A ``defaultdict(list)`` mapping model name -> list of requested
            lower-cased categories handled by that model.

        Raises:
            ValueError: If ``types_str`` is empty or contains no known
                category.
        """
        if not types_str:
            raise ValueError("检测类型为空")

        # dict.fromkeys removes duplicates while keeping first-seen order;
        # a set would make the order depend on string hash randomization.
        cleaned = (item.strip().lower() for item in types_str.split(','))
        requested_types = list(dict.fromkeys(name for name in cleaned if name))

        model_type_map = defaultdict(list)
        for det_type in requested_types:
            model_name = self.type_to_model.get(det_type)
            if model_name is None:
                print(f"忽略未知类别: {det_type}")
                continue
            model_type_map[model_name].append(det_type)

        if not model_type_map:
            raise ValueError("无有效检测类别")

        return model_type_map

    def detect(self, img_path, detection_types):
        """Run every model needed for ``detection_types`` on one image.

        Args:
            img_path: Path of the image to analyse.
            detection_types: Comma-separated category names (see
                ``parse_types``).

        Returns:
            A list with the results of all models concatenated, in the order
            the models were requested. A model that raises is reported and
            contributes nothing; the remaining models still run.

        Raises:
            FileNotFoundError: If ``img_path`` does not exist.
            ValueError: If ``detection_types`` yields no known category.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"图像不存在: {img_path}")

        model_type_map = self.parse_types(detection_types)

        all_results = []
        for model_name, target_types in model_type_map.items():
            detector = self.detectors.get(model_name)
            if detector is None:
                # Defensive: a category is mapped to a model with no detector.
                continue

            print(f"检测: {model_name} -> {target_types}")
            try:
                results = detector.detect(img_path, target_types)
                all_results.extend(results)

                # Per-model statistics are printed for diagnostics only.
                stats = detector.get_detection_stats()
                print(f"  {model_name}详细统计: {stats}")
            except Exception as e:
                # Best effort: keep going with the other models.
                print(f"检测失败 {model_name}: {str(e)}")

        print(f"检测完成: 总共 {len(all_results)} 个结果")
        return all_results

    def get_available_info(self):
        """Describe what the manager can currently detect.

        Returns:
            A dict with ``"loaded_models"`` (list of model names),
            ``"supported_types"`` (list of lower-cased categories) and
            ``"type_to_model"`` (category -> model name). The containers are
            copies, so callers may modify them without affecting routing.
        """
        return {
            "loaded_models": list(self.loaded_models),
            "supported_types": list(self.type_to_model.keys()),
            "type_to_model": dict(self.type_to_model),
        }