import os
import tempfile
from typing import List
from urllib.parse import urlparse

import cv2
import requests
from pydantic import BaseModel
from tqdm import tqdm
from ultralytics import YOLO


# Define the response model
class DetectionResponse(BaseModel):
    hasTarget: int
    originalImgSize: List[int]
    targets: List[dict]
    processing_errors: List[str] = []


def download_large_file(url, chunk_size=1024 * 1024):
    """Download a large file to a temporary file and return the temporary file path."""
    try:
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        file_size = int(response.headers.get('Content-Length', 0))

        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')  # PNG format
        temp_file_path = temp_file.name
        temp_file.close()

        with open(temp_file_path, 'wb') as f, tqdm(
            total=file_size, unit='B', unit_scale=True,
            desc=f"Downloading {os.path.basename(urlparse(url).path)}"
        ) as pbar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))
        return temp_file_path
    except Exception as e:
        error_msg = f"Download failed: {str(e)}"
        print(error_msg)
        # Clean up the partially written temporary file, if it was created
        if 'temp_file_path' in locals():
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        raise Exception(error_msg)


def slice_large_image(image_path, slice_size=1024, overlap=100):
    """Split a large image into slices; return the slice data and their positions."""
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Cannot read image: {image_path}")

    h, w = img.shape[:2]
    step = slice_size - overlap
    num_rows = (h + step - 1) // step
    num_cols = (w + step - 1) // step

    slices = []
    for i in range(num_rows):
        for j in range(num_cols):
            y1 = i * step
            x1 = j * step
            y2 = min(y1 + slice_size, h)
            x2 = min(x1 + slice_size, w)
            # Shift edge windows back so every slice keeps the full slice size
            if y2 - y1 < slice_size:
                y1 = max(0, y2 - slice_size)
            if x2 - x1 < slice_size:
                x1 = max(0, x2 - slice_size)
            slice_img = img[y1:y2, x1:x2]
            slices.append((x1, y1, slice_img))
    return slices, (h, w)


def extract_detection_info(result, slice_offset_x, slice_offset_y):
    """Extract detection boxes from a YOLO OBB result (with corrected width/height computation)."""
    detections = []
    if result.obb is not None and len(result.obb) > 0:
        obb_data = result.obb
        obb_xyxy = obb_data.xyxy.cpu().numpy()
        classes = obb_data.cls.cpu().numpy()
        confidences = obb_data.conf.cpu().numpy()

        for i in range(len(obb_data)):
            x1_slice, y1_slice, x2_slice, y2_slice = obb_xyxy[i]
            # Actual width/height (x direction is width, y direction is height)
            width = x2_slice - x1_slice
            height = y2_slice - y1_slice
            # Convert to global (full-image) coordinates
            x1_global = x1_slice + slice_offset_x
            y1_global = y1_slice + slice_offset_y

            cls_id = int(classes[i])
            confidence = float(confidences[i])
            class_name = result.names[cls_id]

            detection_info = {
                "type": class_name,
                "size": [int(round(width)), int(round(height))],
                "leftTopPoint": [int(round(x1_global)), int(round(y1_global))],
                "score": round(confidence, 4)
            }
            detections.append(detection_info)
    return detections


def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100) -> DetectionResponse:
    """
    Wrapped detection method: process a large image from an image URL and return a DetectionResponse object.
    """
    # Dynamically build the fixed model_path (relative to this file's directory)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt")

    processing_errors = []
    all_detections = []
    original_size = [0, 0]

    try:
        # Verify that the model file exists
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Download the image
        temp_file_path = download_large_file(image_url)

        try:
            # Slice the image
            slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap)
            original_size = [w, h]
            print(f"Slicing done: {len(slices_info)} slices in total")

            # Load the model and run prediction
            model = YOLO(model_path)
            print("Predicting slices one by one...")
            for i, (x1, y1, slice_img) in enumerate(slices_info, 1):
                print(f"Predicting slice {i}/{len(slices_info)}")
                result = model(slice_img, conf=0.5, verbose=False)[0]
                slice_detections = extract_detection_info(result, x1, y1)
                all_detections.extend(slice_detections)
                print(f"  {len(slice_detections)} targets detected in this slice")
        finally:
            # Make sure the temporary file is removed
            if os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                    print("Temporary file removed")
                except Exception as e:
                    error_msg = f"Failed to remove temporary file: {str(e)}"
                    print(error_msg)
                    processing_errors.append(error_msg)
    except Exception as e:
        # Catch all exceptions and record them
        error_msg = str(e)
        processing_errors.append(error_msg)
        print(f"Processing error: {error_msg}")

    # Build and return the DetectionResponse object
    return DetectionResponse(
        hasTarget=1 if len(all_detections) > 0 else 0,
        originalImgSize=original_size,
        targets=all_detections,
        processing_errors=processing_errors
    )
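

# Example usage (illustrative sketch, not part of the original module): the URL below
# is a hypothetical placeholder and must be replaced with a reachable image URL, and the
# model weights are expected at models/solor_bracket_model/best.pt next to this script.
if __name__ == "__main__":
    demo_url = "https://example.com/path/to/large_image.png"  # placeholder URL (assumption)
    response = detect_large_image_from_url(demo_url, slice_size=1024, overlap=100)
    # Print a short summary of the returned DetectionResponse
    print(f"hasTarget={response.hasTarget}, "
          f"originalImgSize={response.originalImgSize}, "
          f"targets={len(response.targets)}, "
          f"errors={response.processing_errors}")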