中煤科工算法合集
This commit is contained in:
313
detect.py
Normal file
313
detect.py
Normal file
@ -0,0 +1,313 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from sahi import AutoDetectionModel
|
||||
from sahi.predict import get_sliced_prediction
|
||||
from ultralytics import YOLO
|
||||
|
||||
from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF
|
||||
|
||||
|
||||
class YOLODetector:
    """Object detector that fuses a whole-image YOLO pass with a SAHI sliced pass.

    Pipeline run by :meth:`detect`:

    1. primary pass   - YOLO on the full image (optionally at several scales);
    2. secondary pass - SAHI sliced prediction (optional);
    3. fusion         - per-class greedy NMS on source-weighted confidence;
    4. post-process   - final confidence filter + centre-distance de-duplication;
    5. formatting     - conversion to the output dict schema.

    Every detection handled internally is a dict with the keys
    ``box`` (``[x1, y1, x2, y2]``), ``conf``, ``class``, ``class_name`` and
    ``source`` (``"primary"`` or ``"secondary"``).
    """

    def __init__(self, model_path, params, type_to_id):
        """Load the YOLO model and, if enabled, the SAHI wrapper around it.

        Args:
            model_path: Path of the weights file handed to ``YOLO`` and SAHI.
            params: Per-model settings mapping. Keys read by this class:
                ``enable_primary`` (default ``True``), ``primary_conf``,
                ``secondary_conf``, ``final_conf`` (each defaulting to
                ``DEFAULT_CONF``), ``enable_secondary`` (required),
                ``enable_multi_scale``, ``multi_scales``, ``slice_size``,
                ``overlap_ratio``, ``weight_primary``, ``weight_secondary``.
            type_to_id: Mapping from a type name to the model's class id.
        """
        # Load the YOLO model and move it to the configured device.
        self.model = YOLO(model_path)
        self.model.to(DEVICE)
        self.class_names = self.model.names
        self.type_to_id = type_to_id

        self.params = params
        self.enable_primary = params.get("enable_primary", True)
        self.primary_conf = params.get("primary_conf", DEFAULT_CONF)      # primary-pass threshold
        self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF)  # secondary-pass threshold
        self.final_conf = params.get("final_conf", DEFAULT_CONF)          # threshold for reported results

        # Class-id filter for the current request (None = all classes).
        # detect() overwrites it on every call; initialising it here keeps the
        # individual *_detect methods usable before the first detect() call.
        self.target_ids = None

        # SAHI model, only created when the secondary pass is enabled.
        self.sahi_model = None
        if params["enable_secondary"]:
            self.sahi_model = AutoDetectionModel.from_pretrained(
                model_type='yolov8',
                model_path=model_path,
                confidence_threshold=self.secondary_conf,
                device=DEVICE
            )

        # Per-stage detection counters, reset by detect().
        self.stats = defaultdict(int)

    @staticmethod
    def _load_image(img_path):
        """Read an image with OpenCV, raising instead of returning ``None``."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError(f"Failed to read image: {img_path}")
        return img

    def _run_model(self, source):
        """Run the primary YOLO model on a path or image array."""
        return self.model(
            source,
            conf=self.primary_conf,
            device=DEVICE,
            classes=self.target_ids,
            verbose=False
        )

    def _parse_results(self, results, scale_x=1.0, scale_y=1.0):
        """Convert YOLO results into detection dicts in original-image coordinates.

        Args:
            results: Iterable of result objects returned by the YOLO model.
            scale_x: Horizontal factor by which the inferred image was resized.
            scale_y: Vertical factor by which the inferred image was resized.

        Returns:
            List of detection dicts tagged with ``source == "primary"``.
        """
        parsed = []
        for result in results:
            # A result may carry no boxes container at all.
            if result.boxes is None:
                continue
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                cls_id = int(box.cls[0])
                parsed.append({
                    # Dividing by 1.0 is exact, so unscaled boxes pass through unchanged.
                    "box": [x1 / scale_x, y1 / scale_y, x2 / scale_x, y2 / scale_y],
                    "conf": box.conf[0].item(),
                    # Stored as int so it matches the ids emitted by the secondary pass.
                    "class": cls_id,
                    "class_name": self.class_names[cls_id],
                    "source": "primary"
                })
        return parsed

    def get_adaptive_slice(self, total_pixels):
        """Pick the SAHI slice size and overlap ratio for an image.

        ``SLICE_RULES`` is scanned in order and the first rule whose pixel
        threshold is exceeded wins; if none matches, the model's own
        ``slice_size`` / ``overlap_ratio`` settings are used.

        Args:
            total_pixels: Width times height of the image.

        Returns:
            Tuple ``(slice_size, overlap_ratio)``.
        """
        for pixel_thresh, (size, overlap) in SLICE_RULES:
            if total_pixels > pixel_thresh:
                return size, overlap
        return self.params["slice_size"], self.params["overlap_ratio"]

    def multi_scale_detect(self, img_path):
        """Run the primary model at every scale in ``params["multi_scales"]``.

        Resized images are passed to the model in memory, so no temporary file
        is written. Boxes are mapped back to original-image coordinates using
        the actual (integer-truncated) resize factors.

        Args:
            img_path: Path of the image to analyse.

        Returns:
            List of detection dicts from all scales (not de-duplicated).

        Raises:
            ValueError: If the image cannot be read.
        """
        img = self._load_image(img_path)
        h, w = img.shape[:2]

        detections = []
        for scale in self.params["multi_scales"]:
            if scale == 1.0:
                # Native scale: let the model read the file itself.
                detections.extend(self._parse_results(self._run_model(img_path)))
                continue

            nw, nh = int(w * scale), int(h * scale)
            if nw < 1 or nh < 1:
                # The scale would collapse the image; cv2.resize rejects empty sizes.
                continue
            scaled_img = cv2.resize(img, (nw, nh))
            detections.extend(
                self._parse_results(self._run_model(scaled_img), nw / w, nh / h)
            )

        return detections

    def primary_detect(self, img_path):
        """Run the primary (whole-image) pass with the primary threshold.

        Args:
            img_path: Path of the image to analyse.

        Returns:
            List of detection dicts; empty when the primary pass is disabled.
        """
        if not self.enable_primary:
            self.stats["primary"] = 0
            print(" 一级检测已禁用,跳过初级检测")
            return []

        if self.params["enable_multi_scale"]:
            detections = self.multi_scale_detect(img_path)
        else:
            detections = self._parse_results(self._run_model(img_path))

        self.stats["primary"] = len(detections)
        return detections

    def secondary_detect(self, img_path):
        """Run the SAHI sliced pass (threshold fixed when the SAHI model was built).

        Detections outside the current class filter, or narrower/shorter than
        ``DEFAULT_MIN_SIZE``, are dropped.

        Args:
            img_path: Path of the image to analyse.

        Returns:
            List of detection dicts tagged with ``source == "secondary"``;
            empty when the secondary pass is disabled.

        Raises:
            ValueError: If the image cannot be read.
        """
        if not self.params["enable_secondary"] or self.sahi_model is None:
            return []

        # The image is read only to size the slices.
        img = self._load_image(img_path)
        h, w = img.shape[:2]
        slice_size, overlap = self.get_adaptive_slice(w * h)

        sliced_results = get_sliced_prediction(
            img_path,
            self.sahi_model,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap,
            overlap_width_ratio=overlap,
            verbose=0
        )

        detections = []
        for obj in sliced_results.object_prediction_list:
            # `is not None` matters: an empty filter must reject everything,
            # not silently fall back to "all classes".
            if self.target_ids is not None and obj.category.id not in self.target_ids:
                continue

            bbox = obj.bbox.to_xyxy()
            bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1]
            if bw < DEFAULT_MIN_SIZE or bh < DEFAULT_MIN_SIZE:
                continue

            detections.append({
                "box": bbox,
                "conf": obj.score.value,
                "class": obj.category.id,
                "class_name": obj.category.name,
                "source": "secondary"
            })

        self.stats["secondary"] = len(detections)
        return detections

    @staticmethod
    def calculate_iou(box1, box2):
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

        Returns 0 when the union area is not positive (degenerate boxes).
        """
        x11, y11, x21, y21 = box1
        x12, y12, x22, y22 = box2

        inter_w = max(0, min(x21, x22) - max(x11, x12))
        inter_h = max(0, min(y21, y22) - max(y11, y12))
        inter_area = inter_w * inter_h

        area1 = (x21 - x11) * (y21 - y11)
        area2 = (x22 - x12) * (y22 - y12)
        union_area = area1 + area2 - inter_area

        return inter_area / union_area if union_area > 0 else 0

    def merge_detections(self, primary_dets, secondary_dets):
        """Fuse primary and secondary detections with per-class greedy NMS.

        When either list is empty the other is returned unchanged. Otherwise
        each detection is copied with an extra ``weighted_conf`` key
        (``conf`` times the source weight), and within each class a detection
        is kept unless an already-kept, higher-ranked one overlaps it with
        IoU above ``DEFAULT_IOU``. The input dicts are not modified.

        Args:
            primary_dets: Detections from the primary pass.
            secondary_dets: Detections from the secondary pass.

        Returns:
            List of surviving detections; ``stats["merged"]`` holds its length.
        """
        if not primary_dets or not secondary_dets:
            passthrough = primary_dets if primary_dets else secondary_dets
            # Keep the counter truthful even though nothing had to be fused.
            self.stats["merged"] = len(passthrough)
            return passthrough

        # Group by class while attaching the source-weighted confidence.
        class_groups = defaultdict(list)
        weighted_sources = (
            (primary_dets, self.params["weight_primary"]),
            (secondary_dets, self.params["weight_secondary"]),
        )
        for dets, weight in weighted_sources:
            for det in dets:
                class_groups[det["class"]].append(
                    {**det, "weighted_conf": det["conf"] * weight}
                )

        merged = []
        for cls_dets in class_groups.values():
            # Stable sort: ties keep primary-before-secondary insertion order.
            cls_dets.sort(key=lambda d: d["weighted_conf"], reverse=True)
            kept = []
            for cand in cls_dets:
                overlaps_kept = any(
                    self.calculate_iou(k["box"], cand["box"]) > DEFAULT_IOU
                    for k in kept
                )
                if not overlaps_kept:
                    kept.append(cand)
            merged.extend(kept)

        self.stats["merged"] = len(merged)
        return merged

    def post_process(self, detections):
        """Apply the final confidence threshold and centre-distance de-duplication.

        Two detections of the same class whose box centres are closer than
        ``DEFAULT_POS_THRESH`` count as duplicates; the one with the higher
        raw ``conf`` is kept, in the position of the first one seen.

        Args:
            detections: Detection dicts (typically from merge_detections).

        Returns:
            Filtered list; ``stats["final"]`` holds its length.
        """
        filtered = [det for det in detections if det["conf"] >= self.final_conf]

        final_dets = []
        for curr_det in filtered:
            curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2
            curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2
            duplicate = False

            for idx, exist_det in enumerate(final_dets):
                if exist_det["class"] != curr_det["class"]:
                    continue

                exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2
                exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2
                dist = np.sqrt((curr_cx - exist_cx) ** 2 + (curr_cy - exist_cy) ** 2)

                if dist < DEFAULT_POS_THRESH:
                    duplicate = True
                    if curr_det["conf"] > exist_det["conf"]:
                        final_dets[idx] = curr_det
                    # Only the first nearby detection is compared.
                    break

            if not duplicate:
                final_dets.append(curr_det)

        self.stats["final"] = len(final_dets)
        return final_dets

    def format_results(self, detections):
        """Convert detection dicts into the output schema.

        Returns:
            List of dicts with ``type`` (class name), ``size``
            (``[width, height]``), ``leftTopPoint`` (``[x, y]``), both rounded
            to integers, and ``score`` (confidence rounded to 4 decimals).
        """
        formatted = []
        for det in detections:
            x1, y1, x2, y2 = det["box"]
            formatted.append({
                "type": det["class_name"],
                "size": [int(round(x2 - x1)), int(round(y2 - y1))],
                "leftTopPoint": [int(round(x1)), int(round(y1))],
                "score": round(det["conf"], 4),
            })
        return formatted

    def get_detection_stats(self):
        """Return a plain-dict snapshot of the per-stage counters."""
        return dict(self.stats)

    def detect(self, img_path, target_types=None):
        """Run the full detection pipeline on one image.

        Args:
            img_path: Path of the image to analyse.
            target_types: Optional iterable of type names to restrict the
                search to; names missing from ``type_to_id`` are ignored.
                ``None`` or empty means all classes.

        Returns:
            Formatted results (see format_results). Empty when target types
            were requested but none of them is known to this model.
        """
        # Reset the counters for this request.
        self.stats = defaultdict(int)

        # Resolve the requested type names to class ids.
        if target_types:
            self.target_ids = [
                self.type_to_id[name] for name in target_types if name in self.type_to_id
            ]
            if not self.target_ids:
                # Nothing this model knows was asked for. Running the passes
                # with an empty filter would be wasted work at best.
                print(" 请求的目标类别均不属于该模型,跳过检测")
                return []
        else:
            self.target_ids = None

        primary_dets = self.primary_detect(img_path)
        print(f" 初级检测后: {self.stats['primary']} 个目标")

        if self.params["enable_secondary"]:
            secondary_dets = self.secondary_detect(img_path)
            print(f" 次级检测后: {self.stats['secondary']} 个目标")
            merged_dets = self.merge_detections(primary_dets, secondary_dets)
            print(f" 融合去重后: {self.stats['merged']} 个目标")
        else:
            merged_dets = primary_dets
            print(" 次级检测未启用")

        processed_dets = self.post_process(merged_dets)
        print(f" 过滤低置信度后: {self.stats['final']} 个目标")

        print(" 最终检测目标详情:")
        for idx, det in enumerate(processed_dets, 1):
            print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}")

        return self.format_results(processed_dets)
|
||||
Reference in New Issue
Block a user