# AI_agent_detect/detect.py

import os
from collections import defaultdict
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from ultralytics import YOLO
from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF


class YOLODetector:
def __init__(self, model_path, params, type_to_id):
        # Load the YOLO model
        self.model = YOLO(model_path)
        self.model.to(DEVICE)
        self.class_names = self.model.names
        self.type_to_id = type_to_id
        self.params = params
        self.target_ids = None  # set per call in detect(); None means "all classes"
        self.enable_primary = params.get("enable_primary", True)
        self.primary_conf = params.get("primary_conf", DEFAULT_CONF)      # primary-pass confidence threshold
        self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF)  # secondary (SAHI) confidence threshold
        self.final_conf = params.get("final_conf", DEFAULT_CONF)          # final display threshold
        # SAHI model for sliced inference (only built when secondary detection is enabled)
        self.sahi_model = None
        if params["enable_secondary"]:
            self.sahi_model = AutoDetectionModel.from_pretrained(
                model_type='yolov8',
                model_path=model_path,
                confidence_threshold=self.secondary_conf,
                device=DEVICE
            )
        # Per-stage detection statistics
        self.stats = defaultdict(int)
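
    # NOTE (illustrative, not from the original file): keys read from `params`
    # elsewhere in this class are enable_primary, primary_conf, secondary_conf,
    # final_conf, enable_secondary, enable_multi_scale, multi_scales, slice_size,
    # overlap_ratio, weight_primary and weight_secondary. A minimal sketch of a
    # plausible params dict (values are assumptions, not the project's defaults):
    #   params = {
    #       "enable_primary": True, "enable_secondary": True,
    #       "enable_multi_scale": False, "multi_scales": [1.0],
    #       "primary_conf": 0.25, "secondary_conf": 0.25, "final_conf": 0.3,
    #       "slice_size": 640, "overlap_ratio": 0.2,
    #       "weight_primary": 1.0, "weight_secondary": 0.8,
    #   }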
def get_adaptive_slice(self, total_pixels):
"""自适应切片参数"""
for pixel_thresh, (size, overlap) in SLICE_RULES:
if total_pixels > pixel_thresh:
return size, overlap
return self.params["slice_size"], self.params["overlap_ratio"]
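
    # Assumed SLICE_RULES layout (illustrative only; the real values live in
    # config.py): an iterable of (pixel_threshold, (slice_size, overlap_ratio))
    # pairs, ordered from the largest threshold down so the first match wins, e.g.
    #   SLICE_RULES = [
    #       (8_000_000, (1024, 0.25)),  # very large images -> bigger slices
    #       (2_000_000, (800, 0.2)),
    #   ]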
def multi_scale_detect(self, img_path):
"""多尺度检测(使用模型专属初级阈值)"""
detections = []
img = cv2.imread(img_path)
h, w = img.shape[:2]
        for scale in self.params["multi_scales"]:
            if scale == 1.0:
                # Detection at the original scale
                results = self.model(
                    img_path,
                    conf=self.primary_conf,  # model-specific primary threshold
                    device=DEVICE,
                    classes=self.target_ids,
                    verbose=False
                )
            else:
                # Detection on a rescaled copy of the image
                nw, nh = int(w * scale), int(h * scale)
                scaled_img = cv2.resize(img, (nw, nh))
                temp_path = f"temp_scale_{scale}.jpg"
                cv2.imwrite(temp_path, scaled_img)
                results = self.model(
                    temp_path,
                    conf=self.primary_conf,  # model-specific primary threshold
                    device=DEVICE,
                    classes=self.target_ids,
                    verbose=False
                )
                os.remove(temp_path)
            # Parse results (core fix: guard against result.boxes being None)
            for result in results:
                # Skip results that carry no boxes
                if result.boxes is None:
                    continue
                for box in result.boxes:
                    bbox = box.xyxy[0].tolist()
                    if scale != 1.0:
                        # Map coordinates back to the original image resolution
                        bbox = [coord / scale for coord in bbox]
detections.append({
"box": bbox,
"conf": box.conf[0].item(),
"class": box.cls[0].item(),
"class_name": self.class_names[int(box.cls[0])],
"source": "primary"
})
return detections
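
    # Note (not from the original file): the Ultralytics model call also accepts an
    # in-memory numpy array (BGR), so the resized image could be passed directly as
    # self.model(scaled_img, ...) instead of round-tripping through a temporary JPEG.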
def primary_detect(self, img_path):
"""初次检测(使用模型专属初级阈值)- 新增enable_primary判断"""
# 新增:如果禁用一级检测,直接返回空列表
if not self.enable_primary:
self.stats["primary"] = 0
print(" 一级检测已禁用,跳过初级检测")
return []
if self.params["enable_multi_scale"]:
detections = self.multi_scale_detect(img_path)
else:
            results = self.model(
                img_path,
                conf=self.primary_conf,  # model-specific primary threshold
                device=DEVICE,
                classes=self.target_ids,
                verbose=False
            )
            # Parse results (core fix: guard against result.boxes being None)
            detections = []
            for result in results:
                # Skip results that carry no boxes
                if result.boxes is None:
                    continue
for box in result.boxes:
detections.append({
"box": box.xyxy[0].tolist(),
"conf": box.conf[0].item(),
"class": box.cls[0].item(),
"class_name": self.class_names[int(box.cls[0])],
"source": "primary"
})
self.stats["primary"] = len(detections)
return detections
def secondary_detect(self, img_path):
"""SAHI切片检测已在初始化时使用模型专属次级阈值"""
if not self.params["enable_secondary"] or not self.sahi_model:
return []
img = cv2.imread(img_path)
h, w = img.shape[:2]
total_pixels = w * h
slice_size, overlap = self.get_adaptive_slice(total_pixels)
# SAHI切片预测
sliced_results = get_sliced_prediction(
img_path,
self.sahi_model,
slice_height=slice_size,
slice_width=slice_size,
overlap_height_ratio=overlap,
overlap_width_ratio=overlap,
verbose=0
)
detections = []
for obj in sliced_results.object_prediction_list:
if self.target_ids and obj.category.id not in self.target_ids:
continue
bbox = obj.bbox.to_xyxy()
bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1]
if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE:
detections.append({
"box": bbox,
"conf": obj.score.value,
"class": obj.category.id,
"class_name": obj.category.name,
"source": "secondary"
})
self.stats["secondary"] = len(detections)
return detections
@staticmethod
def calculate_iou(box1, box2):
"""计算IoU"""
x11, y11, x21, y21 = box1
x12, y12, x22, y22 = box2
inter_x1 = max(x11, x12)
inter_y1 = max(y11, y12)
inter_x2 = min(x21, x22)
inter_y2 = min(y21, y22)
inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
area1 = (x21 - x11) * (y21 - y11)
area2 = (x22 - x12) * (y22 - y12)
union_area = area1 + area2 - inter_area
return inter_area / union_area if union_area > 0 else 0
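
    # Worked example (illustrative): calculate_iou([0, 0, 10, 10], [5, 5, 15, 15])
    # -> intersection 5 * 5 = 25, union 100 + 100 - 25 = 175, IoU = 25 / 175 ≈ 0.143.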
def merge_detections(self, primary_dets, secondary_dets):
"""融合检测结果"""
if not primary_dets:
return secondary_dets
if not secondary_dets:
return primary_dets
# 加权置信度
all_dets = []
for det in primary_dets:
det["weighted_conf"] = det["conf"] * self.params["weight_primary"]
all_dets.append(det)
for det in secondary_dets:
det["weighted_conf"] = det["conf"] * self.params["weight_secondary"]
all_dets.append(det)
# 按类别分组融合
class_groups = defaultdict(list)
for det in all_dets:
class_groups[det["class"]].append(det)
merged = []
for cls_id, cls_dets in class_groups.items():
cls_dets.sort(key=lambda x: x["weighted_conf"], reverse=True)
suppressed = [False] * len(cls_dets)
for i in range(len(cls_dets)):
if suppressed[i]:
continue
merged.append(cls_dets[i])
for j in range(i + 1, len(cls_dets)):
if not suppressed[j] and self.calculate_iou(cls_dets[i]["box"], cls_dets[j]["box"]) > DEFAULT_IOU:
suppressed[j] = True
self.stats["merged"] = len(merged)
return merged
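
    # Illustrative note (not from the original file): within each class the loop
    # above is a greedy NMS on weighted_conf. E.g. two same-class boxes with
    # IoU 0.7 > DEFAULT_IOU keep only the one whose weighted_conf is higher; the
    # kept detection still carries its original conf, which post_process filters on.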
def post_process(self, detections):
"""后处理(使用模型专属最终阈值)"""
# 置信度过滤:模型专属最终阈值
filtered = [det for det in detections if det["conf"] >= self.final_conf]
# 位置去重
final_dets = []
for curr_det in filtered:
curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2
curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2
curr_cls = curr_det["class"]
duplicate = False
for idx, exist_det in enumerate(final_dets):
if exist_det["class"] != curr_cls:
continue
exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2
exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2
                dist = np.sqrt((curr_cx - exist_cx) ** 2 + (curr_cy - exist_cy) ** 2)
if dist < DEFAULT_POS_THRESH:
duplicate = True
if curr_det["conf"] > exist_det["conf"]:
final_dets[idx] = curr_det
break
if not duplicate:
final_dets.append(curr_det)
self.stats["final"] = len(final_dets)
return final_dets
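
    # Illustrative note (not from the original file): two same-class detections whose
    # centres are less than DEFAULT_POS_THRESH pixels apart are treated as the same
    # object; the higher-confidence one wins, and only the first nearby match in
    # final_dets is compared because the inner loop breaks after it.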
def format_results(self, detections):
"""格式化结果"""
formatted = []
for det in detections:
x1, y1, x2, y2 = det["box"]
formatted.append({
"type": det["class_name"],
"size": [int(round(x2 - x1)), int(round(y2 - y1))],
"leftTopPoint": [int(round(x1)), int(round(y1))],
"score": round(det["conf"], 4),
})
return formatted
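
    # Example output element (illustrative values; class names depend on the model):
    #   {"type": "person", "size": [52, 118], "leftTopPoint": [340, 96], "score": 0.8731}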
def get_detection_stats(self):
"""获取检测统计信息"""
return dict(self.stats)
def detect(self, img_path, target_types=None):
"""完整检测流程"""
# 重置统计
self.stats = defaultdict(int)
# 设置目标类别
if target_types:
self.target_ids = [self.type_to_id[cls] for cls in target_types if cls in self.type_to_id]
else:
self.target_ids = None
# 执行检测
primary_dets = self.primary_detect(img_path)
print(f" 初级检测后: {self.stats['primary']} 个目标")
if self.params["enable_secondary"]:
secondary_dets = self.secondary_detect(img_path)
print(f" 次级检测后: {self.stats['secondary']} 个目标")
merged_dets = self.merge_detections(primary_dets, secondary_dets)
print(f" 融合去重后: {self.stats['merged']} 个目标")
else:
merged_dets = primary_dets
print(f" 次级检测未启用")
# 后处理
processed_dets = self.post_process(merged_dets)
print(f" 过滤低置信度后: {self.stats['final']} 个目标")
print(" 最终检测目标详情:")
for idx, det in enumerate(processed_dets, 1):
print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}")
return self.format_results(processed_dets)
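

# Minimal usage sketch (illustrative, not part of the original module): the model
# path, the params values and the type_to_id mapping below are assumptions made for
# this example, not the project's actual configuration.
if __name__ == "__main__":
    example_params = {
        "enable_primary": True,
        "enable_secondary": False,   # keep the example light: primary pass only
        "enable_multi_scale": False,
        "multi_scales": [1.0],
        "primary_conf": 0.25,
        "secondary_conf": 0.25,
        "final_conf": 0.3,
        "slice_size": 640,
        "overlap_ratio": 0.2,
        "weight_primary": 1.0,
        "weight_secondary": 0.8,
    }
    # Hypothetical class-name -> class-id mapping; use the ids of the trained model.
    example_type_to_id = {"person": 0, "car": 2}
    detector = YOLODetector("yolov8n.pt", example_params, example_type_to_id)
    results = detector.detect("example.jpg", target_types=["person"])
    print(results)
    print(detector.get_detection_stats())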