Files
AlgorithmCollection/process.py
2025-12-02 16:43:56 +08:00

181 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import tempfile
from typing import List
from urllib.parse import urlparse
import cv2
import requests
from pydantic import BaseModel
from tqdm import tqdm
from ultralytics import YOLO
# 定义返回值模型
class DetectionResponse(BaseModel):
    """Response payload produced by ``detect_large_image_from_url``.

    Fields:
        hasTarget: 1 if at least one target was detected, otherwise 0.
        originalImgSize: Full image size as ``[width, height]``; stays
            ``[0, 0]`` when the image was never successfully sliced.
        targets: One dict per detection with keys ``type`` (class name),
            ``size`` (``[width, height]`` in pixels), ``leftTopPoint``
            (``[x, y]`` in full-image pixel coordinates) and ``score``
            (confidence rounded to 4 decimals).
        processing_errors: Human-readable error messages collected while
            processing; empty on success.
    """
    hasTarget: int
    originalImgSize: List[int]
    targets: List[dict]
    # pydantic copies this mutable default for each instance, so the list is
    # not shared between responses.
    processing_errors: List[str] = []
def download_large_file(url, chunk_size=1024 * 1024):
    """Stream ``url`` into a new temporary file and return that file's path.

    The body is downloaded in ``chunk_size``-byte chunks, so large images are
    never held fully in memory. A tqdm progress bar is shown while it runs.
    The temporary file is created with ``delete=False``. The caller owns it
    and is responsible for removing it.

    Args:
        url: HTTP(S) URL of the file to download.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        Path of the temporary ``.png``-suffixed file holding the downloaded bytes.

    Raises:
        Exception: If the request, the HTTP status check or the write fails.
            The message starts with ``"下载失败: "`` and the original error is
            chained as ``__cause__``. Any partially written temp file is
            removed before raising.
    """
    temp_file_path = None
    try:
        # Using the response as a context manager releases the connection
        # even on error; with stream=True it otherwise stays checked out.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            # Content-Length may be absent (e.g. chunked transfer). tqdm
            # expects None, not 0, for an unknown total.
            file_size = int(response.headers.get('Content-Length', 0)) or None
            # delete=False: the file must outlive this function. The suffix
            # is kept for downstream readers that look at the extension.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as f:
                # Record the path first so cleanup can find it if anything
                # below fails.
                temp_file_path = f.name
                with tqdm(
                        total=file_size, unit='B', unit_scale=True,
                        desc=f"下载 {os.path.basename(urlparse(url).path)}"
                ) as pbar:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
                            pbar.update(len(chunk))
        return temp_file_path
    except Exception as e:
        error_msg = f"下载失败: {str(e)}"
        print(error_msg)
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                # Best-effort cleanup; the download error is the one to report.
                pass
        raise Exception(error_msg) from e
def _tile_starts(length, slice_size, step):
    """Return the unique, ascending start offsets of tiles covering ``[0, length)``."""
    # A single tile (later clamped by array slicing) covers any extent no
    # longer than one tile.
    if length <= slice_size:
        return [0]
    # Regular strided starts that end strictly inside the extent. Then add one
    # tile flush with the far edge, so the remainder is covered by a
    # full-size tile rather than a thin sliver. This never emits the same
    # start twice.
    starts = list(range(0, length - slice_size, step))
    starts.append(length - slice_size)
    return starts


def slice_large_image(image_path, slice_size=1024, overlap=100):
    """Split a large image into overlapping square slices.

    Neighbouring slices share ``overlap`` pixels. The last slice in each
    row/column is shifted back to end exactly at the image border, so every
    slice is ``slice_size`` pixels wide/tall unless the image itself is
    smaller. Each slice position appears only once.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.
        slice_size: Side length in pixels of each square slice.
        overlap: Number of pixels shared by neighbouring slices; must be
            smaller than ``slice_size``.

    Returns:
        A tuple ``(slices, (h, w))``. ``slices`` is a row-major list of
        ``(x1, y1, slice_img)``, where ``(x1, y1)`` is the slice's top-left
        corner in full-image pixels. ``(h, w)`` is the full image height and
        width.

    Raises:
        ValueError: If ``slice_size <= 0`` or ``overlap >= slice_size``
            (the stride would not advance), or if the image cannot be read.
    """
    if slice_size <= 0 or overlap >= slice_size:
        raise ValueError(
            f"切片参数无效: 需要 slice_size > 0 且 overlap < slice_size, "
            f"当前 slice_size={slice_size}, overlap={overlap}"
        )
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"无法读取图像: {image_path}")
    h, w = img.shape[:2]
    step = slice_size - overlap
    row_starts = _tile_starts(h, slice_size, step)
    col_starts = _tile_starts(w, slice_size, step)
    # Slicing past the array edge is clamped, matching min(y1 + slice_size, h).
    slices = [
        (x1, y1, img[y1:y1 + slice_size, x1:x1 + slice_size])
        for y1 in row_starts
        for x1 in col_starts
    ]
    return slices, (h, w)
def extract_detection_info(result, slice_offset_x, slice_offset_y):
    """Convert one slice's YOLO OBB result into full-image detection dicts.

    The size is taken from the axis-aligned ``xyxy`` box of each oriented
    detection: the x extent is the width and the y extent is the height. The
    top-left corner is shifted by the slice offset into full-image
    coordinates.

    Args:
        result: A single prediction result whose ``obb`` attribute holds the
            oriented boxes (``xyxy``, ``cls``, ``conf``) and whose ``names``
            maps class ids to class names.
        slice_offset_x: X offset of the slice's top-left corner in the full image.
        slice_offset_y: Y offset of the slice's top-left corner in the full image.

    Returns:
        A list of dicts with keys ``type``, ``size`` (``[width, height]``),
        ``leftTopPoint`` (``[x, y]``) and ``score`` (rounded to 4 decimals).
        The list is empty when there are no oriented boxes.
    """
    obb = result.obb
    if obb is None or len(obb) == 0:
        return []

    boxes = obb.xyxy.cpu().numpy()
    class_ids = obb.cls.cpu().numpy()
    scores = obb.conf.cpu().numpy()
    names = result.names

    return [
        {
            "type": names[int(class_id)],
            "size": [int(round(bx2 - bx1)), int(round(by2 - by1))],
            "leftTopPoint": [
                int(round(bx1 + slice_offset_x)),
                int(round(by1 + slice_offset_y)),
            ],
            "score": round(float(score), 4),
        }
        for (bx1, by1, bx2, by2), class_id, score in zip(boxes, class_ids, scores)
    ]
def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100,
                                conf: float = 0.5) -> DetectionResponse:
    """Download an image, run sliced YOLO OBB detection on it, and report the results.

    The image is downloaded to a temporary file, which is always deleted
    afterwards. It is cut into overlapping slices, and every slice is run
    through the model stored at ``models/solor_bracket_model/best.pt`` next to
    this file.

    Args:
        image_url: URL of the image to process.
        slice_size: Side length in pixels of each square slice.
        overlap: Number of pixels shared by neighbouring slices.
        conf: Minimum confidence passed to the model. The default 0.5 is the
            value that was previously hard-coded.

    Returns:
        A ``DetectionResponse``. Any ``Exception`` raised while processing is
        caught and its message appended to ``processing_errors`` instead of
        propagating. Detections collected before the failure are still returned.

    NOTE(review): detections from overlapping slice regions are concatenated
    without any merging/NMS, so one object may be reported more than once.
    """
    # The model path is fixed: <this file's directory>/models/solor_bracket_model/best.pt
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt")
    processing_errors = []
    all_detections = []
    original_size = [0, 0]
    try:
        # Fail fast before spending time on a potentially large download.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在:{model_path}")
        temp_file_path = download_large_file(image_url)
        try:
            slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap)
            original_size = [w, h]
            total = len(slices_info)
            print(f"完成切片: 共 {total} 个切片")
            model = YOLO(model_path)
            print("开始逐张预测切片...")
            for i, (x1, y1, slice_img) in enumerate(slices_info, 1):
                print(f"预测第 {i}/{total} 个切片")
                result = model(slice_img, conf=conf, verbose=False)[0]
                # (x1, y1) is the slice's top-left corner; it converts slice
                # coordinates to full-image coordinates.
                slice_detections = extract_detection_info(result, x1, y1)
                all_detections.extend(slice_detections)
                print(f" 本切片检测到 {len(slice_detections)} 个目标")
        finally:
            # Always remove the downloaded temp file, even if slicing or
            # inference failed. A failed delete is reported, not raised.
            if os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                    print("临时文件已删除")
                except Exception as e:
                    error_msg = f"删除临时文件失败: {str(e)}"
                    print(error_msg)
                    processing_errors.append(error_msg)
    except Exception as e:
        # Boundary of the public entry point: report the error in the response.
        error_msg = str(e)
        processing_errors.append(error_msg)
        print(f"处理异常: {error_msg}")
    return DetectionResponse(
        hasTarget=1 if all_detections else 0,
        originalImgSize=original_size,
        targets=all_detections,
        processing_errors=processing_errors
    )