commit 0ac74b68926f9adaa7a3b5bdfe4a303dbc278cbe Author: ZZX9599 <536509593@qq.com> Date: Tue Dec 2 16:43:56 2025 +0800 中煤科工算法合集 diff --git a/config.py b/config.py new file mode 100644 index 0000000..d96a5a2 --- /dev/null +++ b/config.py @@ -0,0 +1,66 @@ +from typing import List + +import torch +from pydantic import BaseModel + +# 设备配置 +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" + +# 默认检测参数 +DEFAULT_CONF = 0.25 +DEFAULT_IOU = 0.5 +DEFAULT_MIN_SIZE = 8 +DEFAULT_POS_THRESH = 5 + +MODEL_CONFIGS = { + "安全施工模型": { + "model_path": "models/ppe_state_model/best.pt", + "types": ["novest", "nohelmet"], + "type_to_id": {"novest": 0, "nohelmet": 2}, + "params": { + "enable_primary": True, + "primary_conf": 0.55, + "secondary_conf": 0.6, + "final_conf": 0.65, + "enable_multi_scale": True, + "multi_scales": [0.75, 1.0, 1.25], + "enable_secondary": True, + "slice_size": 512, + "overlap_ratio": 0.3, + "weight_primary": 0.4, + "weight_secondary": 0.6 + } + }, + "烟雾火灾模型": { + "model_path": "models/fire_smoke_model/best.pt", + "types": ["fire", "smoke"], + "type_to_id": {"fire": 0, "smoke": 1}, + "params": { + "enable_primary": True, + "primary_conf": 0.99, + "secondary_conf": 0.99, + "final_conf": 0.99, + "enable_multi_scale": True, + "multi_scales": [0.75, 1.0, 1.25], + "enable_secondary": True, + "slice_size": 512, + "overlap_ratio": 0.3, + "weight_primary": 0.4, + "weight_secondary": 0.6 + } + } +} + +# SAHI自适应切片配置 +SLICE_RULES = [ + (12_000_000, (384, 0.35)), + (3_000_000, (512, 0.3)), + (0, (640, 0.25)) +] + + +class DetectionResponse(BaseModel): + hasTarget: int + originalImgSize: List[int] + targets: List[dict] + processing_errors: List[str] = [] diff --git a/detect.py b/detect.py new file mode 100644 index 0000000..878cd9c --- /dev/null +++ b/detect.py @@ -0,0 +1,313 @@ +import os +from collections import defaultdict + +import cv2 +import numpy as np +from sahi import AutoDetectionModel +from sahi.predict import get_sliced_prediction +from ultralytics import YOLO + +from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF + + +class YOLODetector: + def __init__(self, model_path, params, type_to_id): + # 加载YOLO模型 + self.model = YOLO(model_path) + self.model.to(DEVICE) + self.class_names = self.model.names + self.type_to_id = type_to_id + + self.params = params + self.enable_primary = params.get("enable_primary", True) + self.primary_conf = params.get("primary_conf", DEFAULT_CONF) # 初级检测阈值 + self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF) # 次级检测阈值 + self.final_conf = params.get("final_conf", DEFAULT_CONF) # 最终展示阈值 + + # SAHI模型 + self.sahi_model = None + if params["enable_secondary"]: + self.sahi_model = AutoDetectionModel.from_pretrained( + model_type='yolov8', + model_path=model_path, + confidence_threshold=self.secondary_conf, + device=DEVICE + ) + + # 统计 + self.stats = defaultdict(int) + + def get_adaptive_slice(self, total_pixels): + """自适应切片参数""" + for pixel_thresh, (size, overlap) in SLICE_RULES: + if total_pixels > pixel_thresh: + return size, overlap + return self.params["slice_size"], self.params["overlap_ratio"] + + def multi_scale_detect(self, img_path): + """多尺度检测(使用模型专属初级阈值)""" + detections = [] + img = cv2.imread(img_path) + h, w = img.shape[:2] + + for scale in self.params["multi_scales"]: + if scale == 1.0: + # 原尺度检测 + results = self.model( + img_path, + conf=self.primary_conf, # 模型专属初级阈值 + device=DEVICE, + classes=self.target_ids, + verbose=False + ) + else: + # 缩放检测 + nw, nh = int(w * scale), int(h * scale) + scaled_img = cv2.resize(img, (nw, nh)) + temp_path = f"temp_scale_{scale}.jpg" + cv2.imwrite(temp_path, scaled_img) + + results = self.model( + temp_path, + conf=self.primary_conf, # 模型专属初级阈值 + device=DEVICE, + classes=self.target_ids, + verbose=False + ) + os.remove(temp_path) + + # 解析结果(核心修复:增加对result.boxes为None的判断) + for result in results: + # 检查boxes是否存在且非空 + if result.boxes is None: + continue + for box in result.boxes: + bbox = box.xyxy[0].tolist() + if scale != 1.0: + bbox = [coord / scale for coord in bbox] + + detections.append({ + "box": bbox, + "conf": box.conf[0].item(), + "class": box.cls[0].item(), + "class_name": self.class_names[int(box.cls[0])], + "source": "primary" + }) + + return detections + + def primary_detect(self, img_path): + """初次检测(使用模型专属初级阈值)- 新增enable_primary判断""" + # 新增:如果禁用一级检测,直接返回空列表 + if not self.enable_primary: + self.stats["primary"] = 0 + print(" 一级检测已禁用,跳过初级检测") + return [] + + if self.params["enable_multi_scale"]: + detections = self.multi_scale_detect(img_path) + else: + results = self.model( + img_path, + conf=self.primary_conf, # 模型专属初级阈值 + device=DEVICE, + classes=self.target_ids, + verbose=False + ) + # 解析结果(核心修复:增加对result.boxes为None的判断) + detections = [] + for result in results: + # 检查boxes是否存在且非空 + if result.boxes is None: + continue + for box in result.boxes: + detections.append({ + "box": box.xyxy[0].tolist(), + "conf": box.conf[0].item(), + "class": box.cls[0].item(), + "class_name": self.class_names[int(box.cls[0])], + "source": "primary" + }) + + self.stats["primary"] = len(detections) + return detections + + def secondary_detect(self, img_path): + """SAHI切片检测(已在初始化时使用模型专属次级阈值)""" + if not self.params["enable_secondary"] or not self.sahi_model: + return [] + + img = cv2.imread(img_path) + h, w = img.shape[:2] + total_pixels = w * h + slice_size, overlap = self.get_adaptive_slice(total_pixels) + + # SAHI切片预测 + sliced_results = get_sliced_prediction( + img_path, + self.sahi_model, + slice_height=slice_size, + slice_width=slice_size, + overlap_height_ratio=overlap, + overlap_width_ratio=overlap, + verbose=0 + ) + + detections = [] + for obj in sliced_results.object_prediction_list: + if self.target_ids and obj.category.id not in self.target_ids: + continue + + bbox = obj.bbox.to_xyxy() + bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1] + + if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE: + detections.append({ + "box": bbox, + "conf": obj.score.value, + "class": obj.category.id, + "class_name": obj.category.name, + "source": "secondary" + }) + + self.stats["secondary"] = len(detections) + return detections + + @staticmethod + def calculate_iou(box1, box2): + """计算IoU""" + x11, y11, x21, y21 = box1 + x12, y12, x22, y22 = box2 + + inter_x1 = max(x11, x12) + inter_y1 = max(y11, y12) + inter_x2 = min(x21, x22) + inter_y2 = min(y21, y22) + + inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) + area1 = (x21 - x11) * (y21 - y11) + area2 = (x22 - x12) * (y22 - y12) + union_area = area1 + area2 - inter_area + + return inter_area / union_area if union_area > 0 else 0 + + def merge_detections(self, primary_dets, secondary_dets): + """融合检测结果""" + if not primary_dets: + return secondary_dets + if not secondary_dets: + return primary_dets + + # 加权置信度 + all_dets = [] + for det in primary_dets: + det["weighted_conf"] = det["conf"] * self.params["weight_primary"] + all_dets.append(det) + for det in secondary_dets: + det["weighted_conf"] = det["conf"] * self.params["weight_secondary"] + all_dets.append(det) + + # 按类别分组融合 + class_groups = defaultdict(list) + for det in all_dets: + class_groups[det["class"]].append(det) + + merged = [] + for cls_id, cls_dets in class_groups.items(): + cls_dets.sort(key=lambda x: x["weighted_conf"], reverse=True) + suppressed = [False] * len(cls_dets) + + for i in range(len(cls_dets)): + if suppressed[i]: + continue + merged.append(cls_dets[i]) + for j in range(i + 1, len(cls_dets)): + if not suppressed[j] and self.calculate_iou(cls_dets[i]["box"], cls_dets[j]["box"]) > DEFAULT_IOU: + suppressed[j] = True + + self.stats["merged"] = len(merged) + return merged + + def post_process(self, detections): + """后处理(使用模型专属最终阈值)""" + # 置信度过滤:模型专属最终阈值 + filtered = [det for det in detections if det["conf"] >= self.final_conf] + + # 位置去重 + final_dets = [] + for curr_det in filtered: + curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2 + curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2 + curr_cls = curr_det["class"] + duplicate = False + + for idx, exist_det in enumerate(final_dets): + if exist_det["class"] != curr_cls: + continue + + exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2 + exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2 + dist = np.sqrt((curr_cx - exist_cx) **2 + (curr_cy - exist_cy)** 2) + + if dist < DEFAULT_POS_THRESH: + duplicate = True + if curr_det["conf"] > exist_det["conf"]: + final_dets[idx] = curr_det + break + + if not duplicate: + final_dets.append(curr_det) + + self.stats["final"] = len(final_dets) + return final_dets + + def format_results(self, detections): + """格式化结果""" + formatted = [] + for det in detections: + x1, y1, x2, y2 = det["box"] + formatted.append({ + "type": det["class_name"], + "size": [int(round(x2 - x1)), int(round(y2 - y1))], + "leftTopPoint": [int(round(x1)), int(round(y1))], + "score": round(det["conf"], 4), + }) + return formatted + + def get_detection_stats(self): + """获取检测统计信息""" + return dict(self.stats) + + def detect(self, img_path, target_types=None): + """完整检测流程""" + # 重置统计 + self.stats = defaultdict(int) + + # 设置目标类别 + if target_types: + self.target_ids = [self.type_to_id[cls] for cls in target_types if cls in self.type_to_id] + else: + self.target_ids = None + + # 执行检测 + primary_dets = self.primary_detect(img_path) + print(f" 初级检测后: {self.stats['primary']} 个目标") + + if self.params["enable_secondary"]: + secondary_dets = self.secondary_detect(img_path) + print(f" 次级检测后: {self.stats['secondary']} 个目标") + merged_dets = self.merge_detections(primary_dets, secondary_dets) + print(f" 融合去重后: {self.stats['merged']} 个目标") + else: + merged_dets = primary_dets + print(f" 次级检测未启用") + + # 后处理 + processed_dets = self.post_process(merged_dets) + print(f" 过滤低置信度后: {self.stats['final']} 个目标") + + print(" 最终检测目标详情:") + for idx, det in enumerate(processed_dets, 1): + print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}") + + return self.format_results(processed_dets) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..2a9e36f --- /dev/null +++ b/main.py @@ -0,0 +1,126 @@ +import io +import os +import tempfile +from contextlib import asynccontextmanager + +import requests +import uvicorn +from PIL import Image +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, HttpUrl + +from config import DetectionResponse +from process import detect_large_image_from_url + +# 全局检测管理器 +detector_manager = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global detector_manager + try: + from manager import UnifiedDetectionManager + detector_manager = UnifiedDetectionManager() + print("检测管理器初始化成功") + except Exception as e: + print(f"初始化失败:{str(e)}") + raise + yield + + +app = FastAPI(lifespan=lifespan, title="目标检测API", version="1.0.0") + +# 配置跨域请求 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class DetectionRequest(BaseModel): + type: str + url: HttpUrl + + +class DetectionProcessRequest(BaseModel): + url: HttpUrl + + +@app.post("/detect_image", response_model=DetectionResponse) +async def run_detection_image(request: DetectionRequest): + # 解析检测类型 + requested_types = {t.strip().lower() for t in request.type.split(',') if t.strip()} + print(f"请求的检测类型: {requested_types}") + if not requested_types: + raise HTTPException(status_code=400, detail="未指定检测类型") + + # 下载图片 + try: + response = requests.get(str(request.url), timeout=15) + response.raise_for_status() + + # 获取图片尺寸 + with Image.open(io.BytesIO(response.content)) as img: + img_size = [img.width, img.height] + + # 创建临时文件 + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file: + temp_file.write(response.content) + temp_path = temp_file.name + + except Exception as e: + raise HTTPException(status_code=400, detail=f"图片处理失败:{str(e)}") + + # 执行检测 + results = [] + errors = [] + + try: + detection_results = detector_manager.detect(temp_path, ",".join(requested_types)) + if detection_results: + results = detection_results + except Exception as e: + errors.append(f"检测失111111111111111111111败:{str(e)}") + finally: + # 清理临时文件 + if os.path.exists(temp_path): + os.remove(temp_path) + + return { + "hasTarget": 1 if results else 0, + "originalImgSize": img_size, + "targets": results, + "processing_errors": errors + } + + +@app.post("/detect_process", response_model=DetectionResponse) +async def run_detection_process(request: DetectionProcessRequest): + return detect_large_image_from_url(str(request.url)) + + +@app.get("/supported_types") +async def get_supported_types(): + if detector_manager: + info = detector_manager.get_available_info() + return { + "supported_types": info["supported_types"], + } + return {"supported_types": []} + + +if __name__ == "__main__": + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--reload", action="store_true") + args = parser.parse_args() + uvicorn.run("main:app", host=args.host, port=args.port, reload=args.reload) diff --git a/manager.py b/manager.py new file mode 100644 index 0000000..fb76b27 --- /dev/null +++ b/manager.py @@ -0,0 +1,116 @@ +import os +from collections import defaultdict + +from config import MODEL_CONFIGS +from detect import YOLODetector + + +class UnifiedDetectionManager: + """统一检测管理器""" + + def __init__(self): + self.detectors = {} # 检测器实例 + self.type_to_model = {} # 类别到模型映射 + self.loaded_models = [] # 已加载模型 + self.type_to_id = {} # 全局类别ID映射 + + self._load_models() + + def _load_models(self): + """加载所有模型""" + if not MODEL_CONFIGS: + raise ValueError("模型配置为空") + + for model_name, config in MODEL_CONFIGS.items(): + try: + model_path = config["model_path"] + if not os.path.exists(model_path): + print(f"跳过 {model_name}: 模型文件不存在 - {model_path}") + continue + + # 创建检测器(自动传递新增的enable_primary配置) + detector = YOLODetector( + model_path=model_path, + params=config["params"], + type_to_id=config["type_to_id"] + ) + + # 保存状态 + self.detectors[model_name] = detector + self.loaded_models.append(model_name) + + # 建立映射 + for det_type in config["types"]: + det_type_lower = det_type.lower() + if det_type_lower in self.type_to_model: + print(f"警告: 类别 '{det_type}' 映射冲突") + self.type_to_model[det_type_lower] = model_name + self.type_to_id[det_type_lower] = config["type_to_id"][det_type_lower] + + print(f"加载成功: {model_name}") + + except Exception as e: + print(f"加载失败 {model_name}: {str(e)}") + continue + + print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}") + print(f"支持类别: {list(self.type_to_model.keys())}") + + def parse_types(self, types_str): + """解析检测类型""" + if not types_str: + raise ValueError("检测类型为空") + + # 清理输入 + requested_types = list(set(t.strip().lower() for t in types_str.split(',') if t.strip())) + + # 按模型分组 + model_type_map = defaultdict(list) + for det_type in requested_types: + if det_type in self.type_to_model: + model_name = self.type_to_model[det_type] + model_type_map[model_name].append(det_type) + else: + print(f"忽略未知类别: {det_type}") + + if not model_type_map: + raise ValueError("无有效检测类别") + + return model_type_map + + def detect(self, img_path, detection_types): + """执行检测""" + if not os.path.exists(img_path): + raise FileNotFoundError(f"图像不存在: {img_path}") + + # 解析类型 + model_type_map = self.parse_types(detection_types) + + # 执行检测(自动适配enable_primary配置) + all_results = [] + for model_name, target_types in model_type_map.items(): + if model_name not in self.detectors: + continue + + print(f"检测: {model_name} -> {target_types}") + try: + results = self.detectors[model_name].detect(img_path, target_types) + all_results.extend(results) + + # 获取详细统计信息 + stats = self.detectors[model_name].get_detection_stats() + print(f" {model_name}详细统计: {stats}") + + except Exception as e: + print(f"检测失败 {model_name}: {str(e)}") + + print(f"检测完成: 总共 {len(all_results)} 个结果") + return all_results + + def get_available_info(self): + """获取可用信息""" + return { + "loaded_models": self.loaded_models, + "supported_types": list(self.type_to_model.keys()), + "type_to_model": self.type_to_model + } \ No newline at end of file diff --git a/process.py b/process.py new file mode 100644 index 0000000..1383e3c --- /dev/null +++ b/process.py @@ -0,0 +1,180 @@ +import os +import tempfile +from typing import List +from urllib.parse import urlparse + +import cv2 +import requests +from pydantic import BaseModel +from tqdm import tqdm +from ultralytics import YOLO + + +# 定义返回值模型 +class DetectionResponse(BaseModel): + hasTarget: int + originalImgSize: List[int] + targets: List[dict] + processing_errors: List[str] = [] + + +def download_large_file(url, chunk_size=1024 * 1024): + """下载大型文件到临时文件、返回临时文件路径""" + try: + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() + + file_size = int(response.headers.get('Content-Length', 0)) + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') # 适配png格式 + temp_file_path = temp_file.name + temp_file.close() + + with open(temp_file_path, 'wb') as f, tqdm( + total=file_size, unit='B', unit_scale=True, + desc=f"下载 {os.path.basename(urlparse(url).path)}" + ) as pbar: + for chunk in response.iter_content(chunk_size=chunk_size): + if chunk: + f.write(chunk) + pbar.update(len(chunk)) + + return temp_file_path + + except Exception as e: + error_msg = f"下载失败: {str(e)}" + print(error_msg) + if 'temp_file_path' in locals(): + try: + os.remove(temp_file_path) + except: + pass + raise Exception(error_msg) + + +def slice_large_image(image_path, slice_size=1024, overlap=100): + """切分大图为切片、返回切片数据和位置信息""" + img = cv2.imread(image_path) + if img is None: + raise ValueError(f"无法读取图像: {image_path}") + + h, w = img.shape[:2] + step = slice_size - overlap + num_rows = (h + step - 1) // step + num_cols = (w + step - 1) // step + + slices = [] + for i in range(num_rows): + for j in range(num_cols): + y1 = i * step + x1 = j * step + y2 = min(y1 + slice_size, h) + x2 = min(x1 + slice_size, w) + + if y2 - y1 < slice_size: + y1 = max(0, y2 - slice_size) + if x2 - x1 < slice_size: + x1 = max(0, x2 - slice_size) + + slice_img = img[y1:y2, x1:x2] + slices.append((x1, y1, slice_img)) + + return slices, (h, w) + + +def extract_detection_info(result, slice_offset_x, slice_offset_y): + """从YOLO OBB结果中提取检测框信息(修正宽高计算)""" + detections = [] + + if result.obb is not None and len(result.obb) > 0: + obb_data = result.obb + obb_xyxy = obb_data.xyxy.cpu().numpy() + classes = obb_data.cls.cpu().numpy() + confidences = obb_data.conf.cpu().numpy() + + for i in range(len(obb_data)): + x1_slice, y1_slice, x2_slice, y2_slice = obb_xyxy[i] + # 计算实际宽高(x方向为宽,y方向为高) + width = x2_slice - x1_slice + height = y2_slice - y1_slice + + # 转换为全局坐标 + x1_global = x1_slice + slice_offset_x + y1_global = y1_slice + slice_offset_y + + cls_id = int(classes[i]) + confidence = float(confidences[i]) + class_name = result.names[cls_id] + + detection_info = { + "type": class_name, + "size": [int(round(width)), int(round(height))], + "leftTopPoint": [int(round(x1_global)), int(round(y1_global))], + "score": round(confidence, 4) + } + detections.append(detection_info) + + return detections + + +def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100) -> DetectionResponse: + """ + 封装后的检测方法:从图片URL处理大图、返回DetectionResponse对象 + """ + # 动态拼接固定model_path(当前文件同级目录下) + current_dir = os.path.dirname(os.path.abspath(__file__)) + model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt") + + processing_errors = [] + all_detections = [] + original_size = [0, 0] + + try: + # 验证模型文件是否存在 + if not os.path.exists(model_path): + raise FileNotFoundError(f"模型文件不存在:{model_path}") + + # 下载图像 + temp_file_path = download_large_file(image_url) + + try: + # 切分图像 + slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap) + original_size = [w, h] + print(f"完成切片: 共 {len(slices_info)} 个切片") + + # 加载模型并预测 + model = YOLO(model_path) + print("开始逐张预测切片...") + + for i, (x1, y1, slice_img) in enumerate(slices_info, 1): + print(f"预测第 {i}/{len(slices_info)} 个切片") + result = model(slice_img, conf=0.5, verbose=False)[0] + slice_detections = extract_detection_info(result, x1, y1) + all_detections.extend(slice_detections) + print(f" 本切片检测到 {len(slice_detections)} 个目标") + + finally: + # 确保临时文件删除 + if os.path.exists(temp_file_path): + try: + os.remove(temp_file_path) + print("临时文件已删除") + except Exception as e: + error_msg = f"删除临时文件失败: {str(e)}" + print(error_msg) + processing_errors.append(error_msg) + + except Exception as e: + # 捕获所有异常并记录 + error_msg = str(e) + processing_errors.append(error_msg) + print(f"处理异常: {error_msg}") + + # 构建并返回DetectionResponse对象 + return DetectionResponse( + hasTarget=1 if len(all_detections) > 0 else 0, + originalImgSize=original_size, + targets=all_detections, + processing_errors=processing_errors + )