import os
from collections import defaultdict

import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from ultralytics import YOLO

from config import (
    DEVICE,
    DEFAULT_IOU,
    DEFAULT_MIN_SIZE,
    DEFAULT_POS_THRESH,
    SLICE_RULES,
    DEFAULT_CONF,
)


class YOLODetector:
    """Object detector combining plain YOLO inference with optional SAHI slicing.

    The pipeline run by :meth:`detect` is:

    1. primary detection (whole image, optionally at several scales),
    2. secondary detection (SAHI sliced prediction, if enabled),
    3. class-wise fusion of both result sets,
    4. confidence filtering and position-based de-duplication,
    5. formatting into plain dictionaries.

    Internally a detection is a dict with the keys ``box`` (``[x1, y1, x2, y2]``
    in original-image pixels), ``conf``, ``class`` (int id), ``class_name`` and
    ``source`` (``"primary"`` or ``"secondary"``).
    """

    def __init__(self, model_path, params, type_to_id):
        """Load the YOLO model and, if requested, the SAHI wrapper.

        Args:
            model_path: Path to the model weights, passed to ``YOLO`` and SAHI.
            params: Configuration dict. ``enable_secondary`` is required here;
                ``enable_primary``, ``primary_conf``, ``secondary_conf`` and
                ``final_conf`` are optional. Other keys (``slice_size``,
                ``overlap_ratio``, ``multi_scales``, ``enable_multi_scale``,
                ``weight_primary``, ``weight_secondary``) are read lazily by
                the individual methods.
            type_to_id: Mapping from type name to class id, used by
                :meth:`detect` to translate ``target_types``.
        """
        # Load the YOLO model and move it to the configured device.
        self.model = YOLO(model_path)
        self.model.to(DEVICE)
        self.class_names = self.model.names
        self.type_to_id = type_to_id
        self.params = params
        self.enable_primary = params.get("enable_primary", True)
        self.primary_conf = params.get("primary_conf", DEFAULT_CONF)  # primary-stage threshold
        self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF)  # secondary-stage threshold
        self.final_conf = params.get("final_conf", DEFAULT_CONF)  # final display threshold

        # Class-id filter; None means "all classes". detect() overwrites it on
        # every call, but it must exist so that the individual stage methods
        # can also be called directly.
        self.target_ids = None

        # SAHI model (only built when the secondary stage is enabled).
        self.sahi_model = None
        if params["enable_secondary"]:
            self.sahi_model = AutoDetectionModel.from_pretrained(
                model_type='yolov8',
                model_path=model_path,
                confidence_threshold=self.secondary_conf,
                device=DEVICE,
            )

        # Per-run counters, reset at the start of every detect() call.
        self.stats = defaultdict(int)

    @staticmethod
    def _read_image(img_path):
        """Read an image with OpenCV, raising ``ValueError`` if it cannot be loaded."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Unable to read image: {img_path}")
        return img

    def _run_model(self, source):
        """Run the YOLO model on ``source`` (path or BGR array) with the primary settings."""
        return self.model(
            source,
            conf=self.primary_conf,
            device=DEVICE,
            classes=self.target_ids,
            verbose=False,
        )

    def _parse_results(self, results, scale_x=1.0, scale_y=1.0):
        """Convert YOLO results into detection dicts.

        ``scale_x`` / ``scale_y`` are the factors by which the inferred image
        was resized relative to the original; boxes are divided by them so the
        returned coordinates refer to the original image.
        """
        detections = []
        for result in results:
            # A result may carry no boxes at all.
            if result.boxes is None:
                continue
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                if scale_x != 1.0 or scale_y != 1.0:
                    x1, x2 = x1 / scale_x, x2 / scale_x
                    y1, y2 = y1 / scale_y, y2 / scale_y
                # Stored as int so it matches the ids produced by the SAHI stage.
                cls_id = int(box.cls[0].item())
                detections.append({
                    "box": [x1, y1, x2, y2],
                    "conf": box.conf[0].item(),
                    "class": cls_id,
                    "class_name": self.class_names[cls_id],
                    "source": "primary",
                })
        return detections

    def get_adaptive_slice(self, total_pixels):
        """Choose the SAHI slice size and overlap for an image.

        Returns the ``(size, overlap)`` of the first entry in ``SLICE_RULES``
        whose pixel threshold is exceeded by ``total_pixels``; if none matches,
        falls back to ``params["slice_size"]`` and ``params["overlap_ratio"]``.
        """
        # NOTE(review): first match wins, so SLICE_RULES is presumably ordered
        # from the largest threshold to the smallest - confirm in config.
        for pixel_thresh, (size, overlap) in SLICE_RULES:
            if total_pixels > pixel_thresh:
                return size, overlap
        return self.params["slice_size"], self.params["overlap_ratio"]

    def multi_scale_detect(self, img_path):
        """Run primary detection at every scale in ``params["multi_scales"]``.

        Scale ``1.0`` runs on the file itself; other scales run on an in-memory
        resized copy. All boxes are mapped back to original-image coordinates.
        Detections from different scales are simply concatenated here.

        Raises:
            ValueError: If the image cannot be read.
        """
        img = self._read_image(img_path)
        h, w = img.shape[:2]

        detections = []
        for scale in self.params["multi_scales"]:
            if scale == 1.0:
                # Native-scale detection.
                results = self._run_model(img_path)
                scale_x = scale_y = 1.0
            else:
                # Rescaled detection. The resized array is handed to the model
                # directly (Ultralytics accepts HWC BGR arrays), which avoids a
                # temp file in the working directory and a lossy JPEG round-trip.
                nw = max(1, int(w * scale))
                nh = max(1, int(h * scale))
                scaled_img = cv2.resize(img, (nw, nh))
                results = self._run_model(scaled_img)
                # int() truncation makes the real ratios differ slightly from
                # `scale`, so map back with the actual per-axis factors.
                scale_x, scale_y = nw / w, nh / h
            detections.extend(self._parse_results(results, scale_x, scale_y))
        return detections

    def primary_detect(self, img_path):
        """Run the primary (whole-image) detection stage.

        Returns an empty list when the primary stage is disabled. Records the
        number of detections in ``stats["primary"]``.
        """
        # With the primary stage disabled there is nothing to do.
        if not self.enable_primary:
            self.stats["primary"] = 0
            print(" 一级检测已禁用,跳过初级检测")
            return []

        if self.params["enable_multi_scale"]:
            detections = self.multi_scale_detect(img_path)
        else:
            detections = self._parse_results(self._run_model(img_path))

        self.stats["primary"] = len(detections)
        return detections

    def secondary_detect(self, img_path):
        """Run the secondary (SAHI sliced) detection stage.

        The confidence threshold was fixed when the SAHI model was built.
        Results are filtered by ``self.target_ids`` (when set) and boxes
        narrower or shorter than ``DEFAULT_MIN_SIZE`` are dropped. Records the
        number of detections in ``stats["secondary"]``.

        Raises:
            ValueError: If the image cannot be read.
        """
        if not self.params["enable_secondary"] or not self.sahi_model:
            return []

        img = self._read_image(img_path)
        h, w = img.shape[:2]
        slice_size, overlap = self.get_adaptive_slice(w * h)

        # SAHI sliced prediction.
        sliced_results = get_sliced_prediction(
            img_path,
            self.sahi_model,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap,
            overlap_width_ratio=overlap,
            verbose=0,
        )

        # `is not None` rather than truthiness: an empty target list means
        # "none of the requested types exist", not "no filter".
        allowed = None if self.target_ids is None else set(self.target_ids)

        detections = []
        for obj in sliced_results.object_prediction_list:
            if allowed is not None and obj.category.id not in allowed:
                continue
            bbox = obj.bbox.to_xyxy()
            bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1]
            if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE:
                detections.append({
                    "box": bbox,
                    "conf": obj.score.value,
                    "class": obj.category.id,
                    "class_name": obj.category.name,
                    "source": "secondary",
                })

        self.stats["secondary"] = len(detections)
        return detections

    @staticmethod
    def calculate_iou(box1, box2):
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

        Returns 0 when the union area is not positive.
        """
        x11, y11, x21, y21 = box1
        x12, y12, x22, y22 = box2

        inter_w = max(0, min(x21, x22) - max(x11, x12))
        inter_h = max(0, min(y21, y22) - max(y11, y12))
        inter_area = inter_w * inter_h

        area1 = (x21 - x11) * (y21 - y11)
        area2 = (x22 - x12) * (y22 - y12)
        union_area = area1 + area2 - inter_area

        return inter_area / union_area if union_area > 0 else 0

    def merge_detections(self, primary_dets, secondary_dets):
        """Fuse primary and secondary detections.

        If either list is empty the other one is returned unchanged. Otherwise
        each detection gets a ``weighted_conf`` (its confidence times
        ``params["weight_primary"]`` or ``params["weight_secondary"]``) and a
        greedy per-class suppression keeps the highest-weighted box among those
        overlapping with IoU above ``DEFAULT_IOU``. The input dicts are mutated
        in place. The result size is recorded in ``stats["merged"]`` on every
        path.
        """
        if not primary_dets:
            self.stats["merged"] = len(secondary_dets)
            return secondary_dets
        if not secondary_dets:
            self.stats["merged"] = len(primary_dets)
            return primary_dets

        # Weighted confidence, grouped by class.
        class_groups = defaultdict(list)
        for dets, weight_key in (
            (primary_dets, "weight_primary"),
            (secondary_dets, "weight_secondary"),
        ):
            weight = self.params[weight_key]
            for det in dets:
                det["weighted_conf"] = det["conf"] * weight
                class_groups[det["class"]].append(det)

        merged = []
        for cls_dets in class_groups.values():
            cls_dets.sort(key=lambda d: d["weighted_conf"], reverse=True)
            kept = []
            for det in cls_dets:
                # Drop a box if it overlaps too much with a stronger kept box.
                overlaps = any(
                    self.calculate_iou(k["box"], det["box"]) > DEFAULT_IOU
                    for k in kept
                )
                if not overlaps:
                    kept.append(det)
            merged.extend(kept)

        self.stats["merged"] = len(merged)
        return merged

    def post_process(self, detections):
        """Filter by the final confidence threshold and de-duplicate by position.

        Detections with ``conf`` below ``self.final_conf`` are dropped. Two
        detections of the same class whose box centres are closer than
        ``DEFAULT_POS_THRESH`` are treated as duplicates and only the more
        confident one is kept. Records the result size in ``stats["final"]``.
        """
        # Confidence filter using the final threshold.
        filtered = [det for det in detections if det["conf"] >= self.final_conf]

        # Position-based de-duplication.
        final_dets = []
        for curr_det in filtered:
            curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2
            curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2
            curr_cls = curr_det["class"]

            duplicate = False
            for idx, exist_det in enumerate(final_dets):
                if exist_det["class"] != curr_cls:
                    continue
                exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2
                exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2
                dist = np.hypot(curr_cx - exist_cx, curr_cy - exist_cy)
                if dist < DEFAULT_POS_THRESH:
                    duplicate = True
                    # Keep whichever of the two is more confident.
                    if curr_det["conf"] > exist_det["conf"]:
                        final_dets[idx] = curr_det
                    break

            if not duplicate:
                final_dets.append(curr_det)

        self.stats["final"] = len(final_dets)
        return final_dets

    def format_results(self, detections):
        """Convert internal detections into the output format.

        Each output dict has ``type`` (class name), ``size`` (``[width,
        height]``), ``leftTopPoint`` (``[x, y]``), all rounded to integers, and
        ``score`` (confidence rounded to 4 decimals).
        """
        formatted = []
        for det in detections:
            x1, y1, x2, y2 = det["box"]
            formatted.append({
                "type": det["class_name"],
                "size": [int(round(x2 - x1)), int(round(y2 - y1))],
                "leftTopPoint": [int(round(x1)), int(round(y1))],
                "score": round(det["conf"], 4),
            })
        return formatted

    def get_detection_stats(self):
        """Return a plain-dict copy of the counters from the latest run."""
        return dict(self.stats)

    def detect(self, img_path, target_types=None):
        """Run the complete detection pipeline on one image.

        Args:
            img_path: Path of the image to analyse.
            target_types: Optional iterable of type names to restrict detection
                to; names missing from ``type_to_id`` are ignored. ``None`` or
                empty means all classes.

        Returns:
            The list produced by :meth:`format_results`.
        """
        # Reset the counters.
        self.stats = defaultdict(int)

        # Resolve the target classes.
        if target_types:
            self.target_ids = [
                self.type_to_id[cls] for cls in target_types if cls in self.type_to_id
            ]
        else:
            self.target_ids = None

        # Run the detection stages.
        primary_dets = self.primary_detect(img_path)
        print(f" 初级检测后: {self.stats['primary']} 个目标")

        if self.params["enable_secondary"]:
            secondary_dets = self.secondary_detect(img_path)
            print(f" 次级检测后: {self.stats['secondary']} 个目标")
            merged_dets = self.merge_detections(primary_dets, secondary_dets)
            print(f" 融合去重后: {self.stats['merged']} 个目标")
        else:
            merged_dets = primary_dets
            print(" 次级检测未启用")

        # Post-processing.
        processed_dets = self.post_process(merged_dets)
        print(f" 过滤低置信度后: {self.stats['final']} 个目标")

        print(" 最终检测目标详情:")
        for idx, det in enumerate(processed_dets, 1):
            print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}")

        return self.format_results(processed_dets)