Files
AlgorithmCollection/detect.py
2025-12-02 16:43:56 +08:00

313 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from collections import defaultdict
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from ultralytics import YOLO
from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF
class YOLODetector:
    """Two-stage object detector built on an Ultralytics YOLO model.

    The primary stage runs the model on the whole image (optionally at
    several scales); the optional secondary stage runs SAHI sliced
    prediction with the same weights. Both result sets are merged with a
    per-class, confidence-weighted NMS, filtered by a final confidence
    threshold, de-duplicated by box-centre distance and formatted as plain
    dictionaries.

    Keys read from ``params``:
        enable_primary (optional, default True), primary_conf,
        secondary_conf, final_conf (optional, default ``DEFAULT_CONF``),
        enable_secondary (required), enable_multi_scale and multi_scales
        (required by the primary stage), slice_size and overlap_ratio
        (fallback for the secondary stage), weight_primary and
        weight_secondary (required when both stages produce detections).
    """

    def __init__(self, model_path, params, type_to_id):
        """Load the model(s) and cache the per-model thresholds.

        Args:
            model_path: Path to the YOLO weights, passed to ``YOLO`` and,
                when the secondary stage is enabled, to SAHI.
            params: Per-model configuration mapping (see class docstring).
            type_to_id: Mapping from a target type name to a class id.

        Raises:
            KeyError: If ``params`` has no ``"enable_secondary"`` entry.
        """
        # Load the YOLO model used for full-image (primary) detection.
        self.model = YOLO(model_path)
        self.model.to(DEVICE)
        self.class_names = self.model.names
        self.type_to_id = type_to_id
        self.params = params
        self.enable_primary = params.get("enable_primary", True)
        self.primary_conf = params.get("primary_conf", DEFAULT_CONF)      # primary-stage threshold
        self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF)  # secondary-stage threshold
        self.final_conf = params.get("final_conf", DEFAULT_CONF)          # final reporting threshold
        # No class filter until detect() installs one. Defining it here keeps
        # the stage methods callable on their own without an AttributeError.
        self.target_ids = None
        # SAHI model, only built when the secondary stage is enabled.
        self.sahi_model = None
        if params["enable_secondary"]:
            self.sahi_model = AutoDetectionModel.from_pretrained(
                model_type='yolov8',
                model_path=model_path,
                confidence_threshold=self.secondary_conf,
                device=DEVICE
            )
        # Per-stage detection counters.
        self.stats = defaultdict(int)

    @staticmethod
    def _read_image(img_path):
        """Read an image with OpenCV, raising ValueError if it cannot be read."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Failed to read image: {img_path}")
        return img

    def _run_model(self, source):
        """Run the primary model on ``source`` with the primary-stage settings."""
        return self.model(
            source,
            conf=self.primary_conf,
            device=DEVICE,
            classes=self.target_ids,
            verbose=False
        )

    def _parse_results(self, results, scale_x=1.0, scale_y=1.0):
        """Convert model results into detection dicts in original-image coordinates.

        ``scale_x`` / ``scale_y`` are the factors by which the inferred image
        was resized relative to the original; boxes are divided by them.
        """
        detections = []
        for result in results:
            # A result may carry no boxes at all.
            if result.boxes is None:
                continue
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                if scale_x != 1.0 or scale_y != 1.0:
                    x1, x2 = x1 / scale_x, x2 / scale_x
                    y1, y2 = y1 / scale_y, y2 / scale_y
                # Store the class id as int so it has the same type as the
                # ids produced by the secondary stage.
                cls_id = int(box.cls[0].item())
                detections.append({
                    "box": [x1, y1, x2, y2],
                    "conf": box.conf[0].item(),
                    "class": cls_id,
                    "class_name": self.class_names[cls_id],
                    "source": "primary"
                })
        return detections

    def get_adaptive_slice(self, total_pixels):
        """Return ``(slice_size, overlap_ratio)`` for an image of ``total_pixels``.

        The first entry of ``SLICE_RULES`` whose pixel threshold is exceeded
        wins; otherwise the values configured in ``params`` are used.
        """
        for pixel_thresh, (size, overlap) in SLICE_RULES:
            if total_pixels > pixel_thresh:
                return size, overlap
        return self.params["slice_size"], self.params["overlap_ratio"]

    def multi_scale_detect(self, img_path):
        """Run the primary model at every scale in ``params["multi_scales"]``.

        Returns:
            A list of detection dicts whose boxes are mapped back to the
            coordinates of the original image.

        Raises:
            ValueError: If the image cannot be read.
        """
        img = self._read_image(img_path)
        h, w = img.shape[:2]
        detections = []
        for scale in self.params["multi_scales"]:
            if scale == 1.0:
                # Native scale: let the model load the file itself.
                detections.extend(self._parse_results(self._run_model(img_path)))
                continue
            # Never ask OpenCV for a zero-sized image.
            nw, nh = max(1, int(w * scale)), max(1, int(h * scale))
            scaled_img = cv2.resize(img, (nw, nh))
            # Feed the resized array straight to the model instead of
            # round-tripping through a temp JPEG in the working directory,
            # which leaked the file on errors, collided between concurrent
            # runs and added compression loss.
            results = self._run_model(scaled_img)
            # Use the real resize ratios; int() truncation makes them differ
            # slightly from the nominal scale.
            detections.extend(self._parse_results(results, nw / w, nh / h))
        return detections

    def primary_detect(self, img_path):
        """Run the primary (full-image) stage with the primary threshold.

        Returns an empty list when the primary stage is disabled. Records the
        number of detections in ``stats["primary"]``.
        """
        if not self.enable_primary:
            self.stats["primary"] = 0
            print(" 一级检测已禁用,跳过初级检测")
            return []
        if self.params["enable_multi_scale"]:
            detections = self.multi_scale_detect(img_path)
        else:
            detections = self._parse_results(self._run_model(img_path))
        self.stats["primary"] = len(detections)
        return detections

    def secondary_detect(self, img_path):
        """Run SAHI sliced prediction (threshold fixed when the model was built).

        Detections outside the active class filter or smaller than
        ``DEFAULT_MIN_SIZE`` in either dimension are dropped. Records the
        number of kept detections in ``stats["secondary"]``.

        Raises:
            ValueError: If the image cannot be read.
        """
        if not self.params["enable_secondary"] or self.sahi_model is None:
            return []
        img = self._read_image(img_path)
        h, w = img.shape[:2]
        slice_size, overlap = self.get_adaptive_slice(w * h)
        # SAHI sliced prediction.
        sliced_results = get_sliced_prediction(
            img_path,
            self.sahi_model,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap,
            overlap_width_ratio=overlap,
            verbose=0
        )
        # ``None`` means "no filter". An empty filter must reject everything,
        # not silently let every class through.
        wanted = None if self.target_ids is None else set(self.target_ids)
        detections = []
        for obj in sliced_results.object_prediction_list:
            if wanted is not None and obj.category.id not in wanted:
                continue
            bbox = obj.bbox.to_xyxy()
            bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1]
            if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE:
                detections.append({
                    "box": bbox,
                    "conf": obj.score.value,
                    "class": obj.category.id,
                    "class_name": obj.category.name,
                    "source": "secondary"
                })
        self.stats["secondary"] = len(detections)
        return detections

    @staticmethod
    def calculate_iou(box1, box2):
        """Return the IoU of two ``[x1, y1, x2, y2]`` boxes (0 if the union is empty)."""
        x11, y11, x21, y21 = box1
        x12, y12, x22, y22 = box2
        inter_w = max(0, min(x21, x22) - max(x11, x12))
        inter_h = max(0, min(y21, y22) - max(y11, y12))
        inter_area = inter_w * inter_h
        area1 = (x21 - x11) * (y21 - y11)
        area2 = (x22 - x12) * (y22 - y12)
        union_area = area1 + area2 - inter_area
        return inter_area / union_area if union_area > 0 else 0

    def merge_detections(self, primary_dets, secondary_dets):
        """Fuse primary and secondary detections with per-class weighted NMS.

        If either list is empty the other one is returned unchanged. Otherwise
        every detection gets a ``"weighted_conf"`` entry (its confidence times
        the stage weight) and, within each class, boxes overlapping a
        higher-ranked box by more than ``DEFAULT_IOU`` are suppressed.
        ``stats["merged"]`` is updated on every path.
        """
        if not primary_dets or not secondary_dets:
            merged = secondary_dets if not primary_dets else primary_dets
            # Keep the counter in sync even when no fusion is needed.
            self.stats["merged"] = len(merged)
            return merged
        # Weight confidences by stage and group the detections by class.
        class_groups = defaultdict(list)
        for dets, weight_key in ((primary_dets, "weight_primary"),
                                 (secondary_dets, "weight_secondary")):
            weight = self.params[weight_key]
            for det in dets:
                det["weighted_conf"] = det["conf"] * weight
                class_groups[det["class"]].append(det)
        merged = []
        for cls_dets in class_groups.values():
            cls_dets.sort(key=lambda d: d["weighted_conf"], reverse=True)
            suppressed = [False] * len(cls_dets)
            for i, keeper in enumerate(cls_dets):
                if suppressed[i]:
                    continue
                merged.append(keeper)
                for j in range(i + 1, len(cls_dets)):
                    if not suppressed[j] and self.calculate_iou(keeper["box"], cls_dets[j]["box"]) > DEFAULT_IOU:
                        suppressed[j] = True
        self.stats["merged"] = len(merged)
        return merged

    def post_process(self, detections):
        """Apply the final confidence threshold and centre-distance de-duplication.

        Two same-class detections whose box centres are closer than
        ``DEFAULT_POS_THRESH`` count as duplicates; the one with the higher
        confidence is kept. Records the result size in ``stats["final"]``.
        """
        # Confidence filter with the model-specific final threshold.
        filtered = [det for det in detections if det["conf"] >= self.final_conf]
        # Position-based de-duplication.
        final_dets = []
        for curr_det in filtered:
            curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2
            curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2
            curr_cls = curr_det["class"]
            for idx, exist_det in enumerate(final_dets):
                if exist_det["class"] != curr_cls:
                    continue
                exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2
                exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2
                dist = np.sqrt((curr_cx - exist_cx) ** 2 + (curr_cy - exist_cy) ** 2)
                if dist < DEFAULT_POS_THRESH:
                    if curr_det["conf"] > exist_det["conf"]:
                        final_dets[idx] = curr_det
                    break
            else:
                # No duplicate found among the detections kept so far.
                final_dets.append(curr_det)
        self.stats["final"] = len(final_dets)
        return final_dets

    def format_results(self, detections):
        """Convert detection dicts to the output schema.

        Each output item holds ``type``, ``size`` (``[width, height]``),
        ``leftTopPoint`` (``[x, y]``) as rounded ints, and ``score`` rounded
        to four decimals.
        """
        formatted = []
        for det in detections:
            x1, y1, x2, y2 = det["box"]
            formatted.append({
                "type": det["class_name"],
                "size": [int(round(x2 - x1)), int(round(y2 - y1))],
                "leftTopPoint": [int(round(x1)), int(round(y1))],
                "score": round(det["conf"], 4),
            })
        return formatted

    def get_detection_stats(self):
        """Return a plain-dict copy of the per-stage detection counters."""
        return dict(self.stats)

    def detect(self, img_path, target_types=None):
        """Run the full pipeline on one image and return formatted results.

        Args:
            img_path: Path of the image to analyse.
            target_types: Optional iterable of type names to restrict the
                detection to; names missing from ``type_to_id`` are ignored.
                A falsy value disables the class filter.

        Returns:
            The list produced by ``format_results``. It is empty, and no
            inference is run, when ``target_types`` is given but contains no
            known type.
        """
        # Reset the counters for this run.
        self.stats = defaultdict(int)
        # Resolve the class filter.
        if target_types:
            self.target_ids = [self.type_to_id[cls] for cls in target_types if cls in self.type_to_id]
            if not self.target_ids:
                # None of the requested types is known. Running inference with
                # an empty filter would make the secondary stage return every
                # class, so report "nothing found" instead.
                return []
        else:
            self.target_ids = None
        # Run the detection stages.
        primary_dets = self.primary_detect(img_path)
        print(f" 初级检测后: {self.stats['primary']} 个目标")
        if self.params["enable_secondary"]:
            secondary_dets = self.secondary_detect(img_path)
            print(f" 次级检测后: {self.stats['secondary']} 个目标")
            merged_dets = self.merge_detections(primary_dets, secondary_dets)
            print(f" 融合去重后: {self.stats['merged']} 个目标")
        else:
            merged_dets = primary_dets
            print(f" 次级检测未启用")
        # Post-processing.
        processed_dets = self.post_process(merged_dets)
        print(f" 过滤低置信度后: {self.stats['final']} 个目标")
        print(" 最终检测目标详情:")
        for idx, det in enumerate(processed_dets, 1):
            print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}")
        return self.format_results(processed_dets)