Files
AlgorithmCollection/manager.py
2025-12-02 16:43:56 +08:00

116 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from collections import defaultdict
from config import MODEL_CONFIGS
from detect import YOLODetector
class UnifiedDetectionManager:
    """Unified detection manager.

    Loads every model configured in ``MODEL_CONFIGS`` and routes requested
    detection categories to the model responsible for them.

    Attributes (populated by ``_load_models``):
        detectors: model name -> ``YOLODetector`` instance.
        type_to_model: lower-cased category -> name of the model handling it.
        loaded_models: names of successfully loaded models, in load order.
        type_to_id: lower-cased category -> ID taken from the owning model's
            ``type_to_id`` config (merged across all models).
    """

    def __init__(self):
        """Create the manager and eagerly load all configured models.

        Raises:
            ValueError: if ``MODEL_CONFIGS`` is empty.
        """
        self.detectors = {}  # model name -> detector instance
        self.type_to_model = {}  # lower-cased category -> model name
        self.loaded_models = []  # names of successfully loaded models
        self.type_to_id = {}  # global lower-cased category -> class ID
        self._load_models()

    def _load_models(self):
        """Load every model listed in ``MODEL_CONFIGS``.

        A model whose file is missing is skipped. A model whose construction
        or type mapping fails is reported and skipped. Registration is
        all-or-nothing per model: ``detectors``, ``loaded_models`` and the
        type maps are only updated after both the detector and its full
        category mapping were built successfully.

        Raises:
            ValueError: if ``MODEL_CONFIGS`` is empty.
        """
        if not MODEL_CONFIGS:
            raise ValueError("模型配置为空")

        for model_name, config in MODEL_CONFIGS.items():
            try:
                model_path = config["model_path"]
                if not os.path.exists(model_path):
                    print(f"跳过 {model_name}: 模型文件不存在 - {model_path}")
                    continue
                # Original note: the newly added enable_primary setting is
                # forwarded to the detector via ``params`` -- TODO confirm.
                detector = YOLODetector(
                    model_path=model_path,
                    params=config["params"],
                    type_to_id=config["type_to_id"],
                )
                # Resolve the whole category mapping BEFORE registering
                # anything, so a bad entry (e.g. a type_to_id key that is not
                # lower-case) cannot leave this model half-registered.
                type_entries = []
                for det_type in config["types"]:
                    det_type_lower = det_type.lower()
                    type_entries.append(
                        (det_type, det_type_lower, config["type_to_id"][det_type_lower])
                    )
            except Exception as e:
                # Best effort: one broken model must not prevent the others
                # from loading.
                print(f"加载失败 {model_name}: {str(e)}")
                continue

            # Commit the model now that everything it needs is available.
            self.detectors[model_name] = detector
            self.loaded_models.append(model_name)
            for det_type, det_type_lower, type_id in type_entries:
                if det_type_lower in self.type_to_model:
                    # Later models override earlier ones for the same category.
                    print(f"警告: 类别 '{det_type}' 映射冲突")
                self.type_to_model[det_type_lower] = model_name
                self.type_to_id[det_type_lower] = type_id
            print(f"加载成功: {model_name}")

        print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
        print(f"支持类别: {list(self.type_to_model.keys())}")

    def parse_types(self, types_str):
        """Parse a comma-separated category string and group it by model.

        Entries are stripped and lower-cased, and empty entries are dropped.
        Duplicates are removed while keeping first-seen order, so the result
        (and therefore the order in which detectors run) is deterministic.
        Unknown categories are reported and ignored.

        Args:
            types_str: e.g. ``"person, car"``.

        Returns:
            defaultdict(list): model name -> list of requested categories.

        Raises:
            ValueError: if ``types_str`` is empty or no category is known.
        """
        if not types_str:
            raise ValueError("检测类型为空")

        # dict.fromkeys de-duplicates while preserving input order; a set()
        # here made the grouping order vary from run to run.
        requested_types = list(dict.fromkeys(
            t.strip().lower() for t in types_str.split(',') if t.strip()
        ))

        model_type_map = defaultdict(list)
        for det_type in requested_types:
            model_name = self.type_to_model.get(det_type)
            if model_name is None:
                print(f"忽略未知类别: {det_type}")
                continue
            model_type_map[model_name].append(det_type)

        if not model_type_map:
            raise ValueError("无有效检测类别")
        return model_type_map

    def detect(self, img_path, detection_types):
        """Run detection on one image for the requested categories.

        Args:
            img_path: path to the image file.
            detection_types: comma-separated category names (case-insensitive).

        Returns:
            list: the concatenated results of each involved detector's
            ``detect`` call. A detector that raises is reported and skipped
            (best effort).

        Raises:
            FileNotFoundError: if ``img_path`` does not exist.
            ValueError: propagated from ``parse_types``.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"图像不存在: {img_path}")

        model_type_map = self.parse_types(detection_types)

        all_results = []
        for model_name, target_types in model_type_map.items():
            detector = self.detectors.get(model_name)
            if detector is None:
                continue
            print(f"检测: {model_name} -> {target_types}")
            try:
                all_results.extend(detector.detect(img_path, target_types))
            except Exception as e:
                print(f"检测失败 {model_name}: {str(e)}")
                continue
            # Statistics are diagnostic only: a failure here must not be
            # reported as a detection failure, since results were collected.
            try:
                stats = detector.get_detection_stats()
            except Exception as e:
                print(f"获取统计失败 {model_name}: {str(e)}")
            else:
                print(f" {model_name}详细统计: {stats}")

        print(f"检测完成: 总共 {len(all_results)} 个结果")
        return all_results

    def get_available_info(self):
        """Return a snapshot of the loaded models and supported categories.

        Copies are returned so callers cannot mutate the manager's state.

        Returns:
            dict with keys ``loaded_models`` (list), ``supported_types``
            (list of lower-cased categories) and ``type_to_model`` (dict).
        """
        return {
            "loaded_models": list(self.loaded_models),
            "supported_types": list(self.type_to_model.keys()),
            "type_to_model": dict(self.type_to_model),
        }