181 lines
5.9 KiB
Python
181 lines
5.9 KiB
Python
|
|
import os
|
|||
|
|
import tempfile
|
|||
|
|
from typing import List
|
|||
|
|
from urllib.parse import urlparse
|
|||
|
|
|
|||
|
|
import cv2
|
|||
|
|
import requests
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Response model returned by the detection entry point.
class DetectionResponse(BaseModel):
    """Detection result for one image, as built by detect_large_image_from_url."""

    # 1 when at least one target was detected, otherwise 0.
    hasTarget: int
    # Image size as [width, height] in pixels; [0, 0] when the image was never read.
    originalImgSize: List[int]
    # One dict per detection, with keys "type", "size", "leftTopPoint" and "score".
    targets: List[dict]
    # Error messages collected while processing; empty when everything succeeded.
    processing_errors: List[str] = []
|
|||
|
|
|
|||
|
|
|
|||
|
|
def download_large_file(url, chunk_size=1024 * 1024, timeout=30):
    """Stream ``url`` into a temporary file and return that file's path.

    The body is written in ``chunk_size``-byte pieces, so a large image never
    has to be held in memory in full. A tqdm progress bar labelled with the
    last path component of the URL shows progress.

    Args:
        url: HTTP(S) URL of the file to download.
        chunk_size: Bytes requested per ``iter_content`` chunk.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        Path of the temporary file that holds the downloaded bytes. The file is
        created with ``delete=False``, so the caller must remove it.

    Raises:
        Exception: If the request or the write fails. The message is prefixed
            with "下载失败" ("download failed") and the original error is
            chained as ``__cause__``. Any partially written temp file is
            removed first.
    """
    temp_file_path = None
    try:
        # The ``with`` block releases the streamed connection even when
        # writing fails part-way through.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # 0 when the server sends no Content-Length header.
            file_size = int(response.headers.get('Content-Length', 0))

            # delete=False: the file must outlive this function (the caller
            # deletes it). The .png suffix is kept from the original code.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as f:
                # Record the path before anything else can fail, so the
                # except-branch can clean it up.
                temp_file_path = f.name
                with tqdm(
                    total=file_size, unit='B', unit_scale=True,
                    desc=f"下载 {os.path.basename(urlparse(url).path)}"
                ) as pbar:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if chunk:  # skip empty chunks
                            f.write(chunk)
                            pbar.update(len(chunk))

        return temp_file_path

    except Exception as e:
        error_msg = f"下载失败: {str(e)}"
        print(error_msg)
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                # Best-effort cleanup: the download error is the one to report.
                pass
        raise Exception(error_msg) from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _slice_starts(length, slice_size, step):
    """Return the sorted, unique tile start offsets along one image axis.

    Starts advance by ``step``. The final tile is pinned to
    ``length - slice_size`` so that it ends exactly at the image border. An
    axis no longer than ``slice_size`` gets a single tile starting at 0.
    """
    last = length - slice_size
    if last <= 0:
        return [0]
    # range() stops before ``last``, so appending it never creates a duplicate.
    starts = list(range(0, last, step))
    starts.append(last)
    return starts


def slice_large_image(image_path, slice_size=1024, overlap=100):
    """Split a large image into overlapping tiles.

    Tiles are laid out on a grid with stride ``slice_size - overlap``. The last
    tile in each row and column is shifted back so that it ends at the image
    border, which keeps it full-size whenever the image is large enough. Every
    tile origin appears exactly once. The previous ``ceil(length / step)``
    layout could emit the same tile twice, which duplicated detections
    downstream.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        slice_size: Tile edge length in pixels.
        overlap: Pixels shared by neighbouring tiles. Must be smaller than
            ``slice_size``.

    Returns:
        ``(slices, (h, w))``:
        - ``slices`` is a row-major list of ``(x1, y1, tile)`` tuples.
          ``(x1, y1)`` is the tile's top-left corner in full-image pixels, and
          ``tile`` is ``img[y1:y2, x1:x2]``.
        - ``(h, w)`` is the full image height and width.

    Raises:
        ValueError: If ``slice_size`` is not positive, if ``overlap`` is not
            smaller than ``slice_size`` (the stride would be <= 0), or if the
            image cannot be read.
    """
    step = slice_size - overlap
    if slice_size <= 0 or step <= 0:
        raise ValueError(
            f"切片参数无效: 需要 slice_size > 0 且 overlap < slice_size, "
            f"当前 slice_size={slice_size}, overlap={overlap}"
        )

    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"无法读取图像: {image_path}")

    h, w = img.shape[:2]
    row_starts = _slice_starts(h, slice_size, step)
    col_starts = _slice_starts(w, slice_size, step)

    slices = [
        (x1, y1, img[y1:min(y1 + slice_size, h), x1:min(x1 + slice_size, w)])
        for y1 in row_starts
        for x1 in col_starts
    ]
    return slices, (h, w)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_detection_info(result, slice_offset_x, slice_offset_y):
    """Convert one tile's YOLO OBB result into detection dicts in full-image coordinates.

    Each detection is taken from the result's ``obb.xyxy`` box.
    - The reported size is that box's x-extent (width) and y-extent (height).
    - Its top-left corner is shifted by the tile offset, so it is expressed in
      full-image pixels.

    Args:
        result: A single prediction result exposing ``obb`` (or ``None``) and
            a ``names`` mapping from class id to class name.
        slice_offset_x: x of the tile's top-left corner in the full image.
        slice_offset_y: y of the tile's top-left corner in the full image.

    Returns:
        A list of dicts, empty when there are no OBB detections. Each dict has:
        - ``"type"``: the class name.
        - ``"size"``: ``[width, height]``, rounded to int.
        - ``"leftTopPoint"``: ``[x, y]``, rounded to int.
        - ``"score"``: the confidence, rounded to 4 decimals.
    """
    obb = result.obb
    if obb is None or len(obb) == 0:
        return []

    boxes = obb.xyxy.cpu().numpy()
    class_ids = obb.cls.cpu().numpy()
    scores = obb.conf.cpu().numpy()

    found = []
    for (left, top, right, bottom), raw_cls, raw_score in zip(boxes, class_ids, scores):
        found.append({
            "type": result.names[int(raw_cls)],
            # Width is the x-extent, height the y-extent (tile-local, offset-independent).
            "size": [int(round(right - left)), int(round(bottom - top))],
            # Shift the tile-local corner into full-image coordinates.
            "leftTopPoint": [int(round(left + slice_offset_x)),
                             int(round(top + slice_offset_y))],
            "score": round(float(raw_score), 4),
        })
    return found
|
|||
|
|
|
|||
|
|
|
|||
|
|
def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100,
                                *, conf: float = 0.5, model_path=None) -> DetectionResponse:
    """Download an image, run the YOLO model tile by tile, and collect the detections.

    Processing failures are never raised to the caller. They are caught, and
    their messages are returned in ``processing_errors`` next to whatever
    detections were gathered before the failure.

    Args:
        image_url: URL of the image to analyse.
        slice_size: Tile edge length in pixels (forwarded to slice_large_image).
        overlap: Pixel overlap between neighbouring tiles.
        conf: Confidence threshold passed to the model call. Defaults to the
            previously hard-coded 0.5.
        model_path: Path to the model weights. Defaults to
            ``models/solor_bracket_model/best.pt`` next to this file.

    Returns:
        DetectionResponse with these fields:
        - ``hasTarget``: 1 if anything was detected, otherwise 0.
        - ``originalImgSize``: ``[width, height]``, or ``[0, 0]`` if the image
          was never read.
        - ``targets``: per-tile detections in full-image coordinates. Tiles
          overlap and no cross-tile merging is done here, so an object inside
          an overlap region may appear more than once.
        - ``processing_errors``: error messages collected during processing.
    """
    if model_path is None:
        # Default weights live next to this file (directory name kept as-is).
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt")

    processing_errors = []
    all_detections = []
    original_size = [0, 0]

    try:
        # Fail before downloading if the weights are missing.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在:{model_path}")

        temp_file_path = download_large_file(image_url)

        try:
            slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap)
            original_size = [w, h]
            print(f"完成切片: 共 {len(slices_info)} 个切片")

            model = YOLO(model_path)
            print("开始逐张预测切片...")

            for i, (x1, y1, slice_img) in enumerate(slices_info, 1):
                print(f"预测第 {i}/{len(slices_info)} 个切片")
                result = model(slice_img, conf=conf, verbose=False)[0]
                # (x1, y1) is the tile origin, used to map boxes to full-image coordinates.
                slice_detections = extract_detection_info(result, x1, y1)
                all_detections.extend(slice_detections)
                print(f" 本切片检测到 {len(slice_detections)} 个目标")

        finally:
            # Always remove the downloaded temp file, even if slicing or inference failed.
            if os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                    print("临时文件已删除")
                except Exception as e:
                    error_msg = f"删除临时文件失败: {str(e)}"
                    print(error_msg)
                    processing_errors.append(error_msg)

    except Exception as e:
        # str(e) is empty for exceptions raised without a message; fall back
        # to the exception type name so the failure is still reported.
        error_msg = str(e) or type(e).__name__
        processing_errors.append(error_msg)
        print(f"处理异常: {error_msg}")

    return DetectionResponse(
        hasTarget=1 if all_detections else 0,
        originalImgSize=original_size,
        targets=all_detections,
        processing_errors=processing_errors,
    )
|