中煤科工算法合集
This commit is contained in:
116
manager.py
Normal file
116
manager.py
Normal file
@ -0,0 +1,116 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
from config import MODEL_CONFIGS
|
||||
from detect import YOLODetector
|
||||
|
||||
|
||||
class UnifiedDetectionManager:
|
||||
"""统一检测管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.detectors = {} # 检测器实例
|
||||
self.type_to_model = {} # 类别到模型映射
|
||||
self.loaded_models = [] # 已加载模型
|
||||
self.type_to_id = {} # 全局类别ID映射
|
||||
|
||||
self._load_models()
|
||||
|
||||
def _load_models(self):
|
||||
"""加载所有模型"""
|
||||
if not MODEL_CONFIGS:
|
||||
raise ValueError("模型配置为空")
|
||||
|
||||
for model_name, config in MODEL_CONFIGS.items():
|
||||
try:
|
||||
model_path = config["model_path"]
|
||||
if not os.path.exists(model_path):
|
||||
print(f"跳过 {model_name}: 模型文件不存在 - {model_path}")
|
||||
continue
|
||||
|
||||
# 创建检测器(自动传递新增的enable_primary配置)
|
||||
detector = YOLODetector(
|
||||
model_path=model_path,
|
||||
params=config["params"],
|
||||
type_to_id=config["type_to_id"]
|
||||
)
|
||||
|
||||
# 保存状态
|
||||
self.detectors[model_name] = detector
|
||||
self.loaded_models.append(model_name)
|
||||
|
||||
# 建立映射
|
||||
for det_type in config["types"]:
|
||||
det_type_lower = det_type.lower()
|
||||
if det_type_lower in self.type_to_model:
|
||||
print(f"警告: 类别 '{det_type}' 映射冲突")
|
||||
self.type_to_model[det_type_lower] = model_name
|
||||
self.type_to_id[det_type_lower] = config["type_to_id"][det_type_lower]
|
||||
|
||||
print(f"加载成功: {model_name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"加载失败 {model_name}: {str(e)}")
|
||||
continue
|
||||
|
||||
print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
|
||||
print(f"支持类别: {list(self.type_to_model.keys())}")
|
||||
|
||||
def parse_types(self, types_str):
|
||||
"""解析检测类型"""
|
||||
if not types_str:
|
||||
raise ValueError("检测类型为空")
|
||||
|
||||
# 清理输入
|
||||
requested_types = list(set(t.strip().lower() for t in types_str.split(',') if t.strip()))
|
||||
|
||||
# 按模型分组
|
||||
model_type_map = defaultdict(list)
|
||||
for det_type in requested_types:
|
||||
if det_type in self.type_to_model:
|
||||
model_name = self.type_to_model[det_type]
|
||||
model_type_map[model_name].append(det_type)
|
||||
else:
|
||||
print(f"忽略未知类别: {det_type}")
|
||||
|
||||
if not model_type_map:
|
||||
raise ValueError("无有效检测类别")
|
||||
|
||||
return model_type_map
|
||||
|
||||
def detect(self, img_path, detection_types):
|
||||
"""执行检测"""
|
||||
if not os.path.exists(img_path):
|
||||
raise FileNotFoundError(f"图像不存在: {img_path}")
|
||||
|
||||
# 解析类型
|
||||
model_type_map = self.parse_types(detection_types)
|
||||
|
||||
# 执行检测(自动适配enable_primary配置)
|
||||
all_results = []
|
||||
for model_name, target_types in model_type_map.items():
|
||||
if model_name not in self.detectors:
|
||||
continue
|
||||
|
||||
print(f"检测: {model_name} -> {target_types}")
|
||||
try:
|
||||
results = self.detectors[model_name].detect(img_path, target_types)
|
||||
all_results.extend(results)
|
||||
|
||||
# 获取详细统计信息
|
||||
stats = self.detectors[model_name].get_detection_stats()
|
||||
print(f" {model_name}详细统计: {stats}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"检测失败 {model_name}: {str(e)}")
|
||||
|
||||
print(f"检测完成: 总共 {len(all_results)} 个结果")
|
||||
return all_results
|
||||
|
||||
def get_available_info(self):
|
||||
"""获取可用信息"""
|
||||
return {
|
||||
"loaded_models": self.loaded_models,
|
||||
"supported_types": list(self.type_to_model.keys()),
|
||||
"type_to_model": self.type_to_model
|
||||
}
|
||||
Reference in New Issue
Block a user