Files
AI_agent_detect/process.py

181 lines
5.9 KiB
Python
Raw Normal View History

2025-12-02 17:16:26 +08:00
import os
import tempfile
from typing import List
from urllib.parse import urlparse
import cv2
import requests
from pydantic import BaseModel
from tqdm import tqdm
from ultralytics import YOLO
# Response model for the detection result.
class DetectionResponse(BaseModel):
    """Result of a large-image detection run, as built by detect_large_image_from_url."""
    # 1 if at least one target was detected, otherwise 0
    hasTarget: int
    # Original image size as [width, height]; [0, 0] if the image was never read
    originalImgSize: List[int]
    # One dict per detection with keys "type", "size", "leftTopPoint", "score"
    targets: List[dict]
    # Error messages collected during processing; empty when everything succeeded
    processing_errors: List[str] = []
def download_large_file(url, chunk_size=1024 * 1024, suffix='.png'):
    """Stream a (possibly large) file from ``url`` into a temporary file.

    The body is written chunk by chunk, so the whole file never has to fit in
    memory. A tqdm progress bar is sized from ``Content-Length`` when the
    server sends one.

    Args:
        url: HTTP(S) URL to download.
        chunk_size: Number of bytes requested per ``iter_content`` chunk.
        suffix: Suffix for the temporary file. The default '.png' is the
            previously hard-coded value.

    Returns:
        Path of the temporary file. It is created with ``delete=False``, so the
        caller is responsible for removing it.

    Raises:
        Exception: On any failure (connection error, timeout, HTTP error
            status, write error). Any partially written temp file is removed
            first, and the original exception is chained as ``__cause__``.
    """
    temp_file_path = None
    try:
        # Using the response as a context manager releases the streamed
        # connection even if the download is interrupted halfway.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            # Unknown length -> total=None, so tqdm shows an open-ended counter.
            file_size = int(response.headers.get('Content-Length', 0)) or None
            # Write through the temp file's own handle instead of closing and
            # reopening it by name.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f, tqdm(
                    total=file_size, unit='B', unit_scale=True,
                    desc=f"下载 {os.path.basename(urlparse(url).path)}"
            ) as pbar:
                temp_file_path = f.name
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
                        pbar.update(len(chunk))
        return temp_file_path
    except Exception as e:
        error_msg = f"下载失败: {str(e)}"
        print(error_msg)
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                # Best-effort cleanup: the download error is what gets reported.
                pass
        raise Exception(error_msg) from e
def _tile_starts(length, slice_size, step):
    """Return the distinct tile start offsets along one image axis of size ``length``."""
    last = length - slice_size
    if last <= 0:
        # The whole axis fits in one tile, which is then shorter than slice_size.
        return [0]
    starts = list(range(0, last, step))
    # The final tile is shifted back so it ends exactly on the image border.
    starts.append(last)
    return starts


def slice_large_image(image_path, slice_size=1024, overlap=100):
    """Cut a large image into overlapping tiles.

    Tiles are produced row by row with a stride of ``slice_size - overlap``.
    The last row and column are shifted back to end flush with the image
    border. Every tile is therefore ``slice_size`` pixels along an axis,
    unless the image itself is smaller than that. Each tile position appears
    only once; the earlier stepping could emit the same tile twice, which
    duplicated every detection in it.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        slice_size: Tile side length in pixels.
        overlap: Number of pixels shared by neighbouring tiles.

    Returns:
        ``(slices, (h, w))``.
        - ``slices`` is a list of ``(x1, y1, tile)`` tuples. ``x1``/``y1`` are
          the tile's top-left corner in full-image coordinates, and ``tile``
          is the corresponding slice of the image array.
        - ``h``/``w`` are the image height and width.

    Raises:
        ValueError: If ``overlap`` is not in ``[0, slice_size)``, or if the
            image cannot be read.
    """
    step = slice_size - overlap
    if overlap < 0 or step <= 0:
        # step <= 0 would never advance; a negative overlap would leave gaps.
        raise ValueError(
            f"无效的切片参数: slice_size={slice_size}, overlap={overlap} "
            f"(要求 0 <= overlap < slice_size)"
        )
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"无法读取图像: {image_path}")
    h, w = img.shape[:2]
    x_starts = _tile_starts(w, slice_size, step)
    slices = []
    for y1 in _tile_starts(h, slice_size, step):
        y2 = min(y1 + slice_size, h)
        for x1 in x_starts:
            x2 = min(x1 + slice_size, w)
            slices.append((x1, y1, img[y1:y2, x1:x2]))
    return slices, (h, w)
def extract_detection_info(result, slice_offset_x, slice_offset_y):
    """Convert one tile's YOLO OBB result into full-image detection dicts.

    Each oriented box is reduced to its axis-aligned bounding box
    (``obb.xyxy``). Width is the box's x extent and height its y extent. The
    top-left corner is shifted by the tile offset, so it is expressed in
    full-image coordinates.

    Args:
        result: A single YOLO result exposing ``.obb`` and ``.names``.
        slice_offset_x: x coordinate of the tile's top-left corner in the full image.
        slice_offset_y: y coordinate of the tile's top-left corner in the full image.

    Returns:
        A list of dicts with these keys:
        - ``"type"``: class name.
        - ``"size"``: ``[w, h]``.
        - ``"leftTopPoint"``: ``[x, y]``.
        - ``"score"``: confidence rounded to 4 decimals.
        All sizes and coordinates are rounded to int. The list is empty when
        the result has no OBB detections.
    """
    obb = result.obb
    if obb is None or len(obb) == 0:
        return []

    boxes = obb.xyxy.cpu().numpy()
    class_ids = obb.cls.cpu().numpy()
    scores = obb.conf.cpu().numpy()
    names = result.names

    found = []
    for idx in range(len(obb)):
        left, top, right, bottom = boxes[idx]
        found.append({
            "type": names[int(class_ids[idx])],
            "size": [int(round(right - left)), int(round(bottom - top))],
            "leftTopPoint": [
                int(round(left + slice_offset_x)),
                int(round(top + slice_offset_y)),
            ],
            "score": round(float(scores[idx]), 4),
        })
    return found
def _remove_temp_file(temp_file_path, processing_errors):
    """Best-effort delete of the downloaded temp file; failures are recorded, not raised."""
    try:
        os.remove(temp_file_path)
    except FileNotFoundError:
        return  # already gone: nothing to clean up
    except OSError as e:
        error_msg = f"删除临时文件失败: {str(e)}"
        print(error_msg)
        processing_errors.append(error_msg)
        return
    print("临时文件已删除")


def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100,
                                conf: float = 0.5) -> DetectionResponse:
    """
    Run tiled YOLO detection on an image downloaded from a URL.

    Steps:
    1. Download the image to a temporary file.
    2. Cut it into overlapping tiles with ``slice_large_image``.
    3. Run every tile through the model at
       ``models/solor_bracket_model/best.pt`` (next to this file).
    4. Map the detections back to full-image coordinates.

    Detections from neighbouring tiles are not merged. An object lying in an
    overlap region may therefore be reported more than once.

    Args:
        image_url: URL of the image to process.
        slice_size: Tile side length in pixels.
        overlap: Overlap between neighbouring tiles in pixels.
        conf: Minimum confidence passed to the model. The default 0.5 is the
            previously hard-coded value.

    Returns:
        A DetectionResponse. This function does not raise. Any error message
        is collected in ``processing_errors``, and detections gathered before
        the failure are still returned.
    """
    # The model path is fixed relative to this file's directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt")
    processing_errors = []
    all_detections = []
    original_size = [0, 0]
    try:
        # Check the model before downloading so a missing model fails fast.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在:{model_path}")
        temp_file_path = download_large_file(image_url)
        try:
            slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap)
            original_size = [w, h]
            num_slices = len(slices_info)
            print(f"完成切片: 共 {num_slices} 个切片")
            model = YOLO(model_path)
            print("开始逐张预测切片...")
            for i, (x1, y1, slice_img) in enumerate(slices_info, 1):
                print(f"预测第 {i}/{num_slices} 个切片")
                result = model(slice_img, conf=conf, verbose=False)[0]
                slice_detections = extract_detection_info(result, x1, y1)
                all_detections.extend(slice_detections)
                print(f" 本切片检测到 {len(slice_detections)} 个目标")
        finally:
            # Always remove the downloaded file, even if slicing or inference failed.
            _remove_temp_file(temp_file_path, processing_errors)
    except Exception as e:
        # Deliberately broad: this is the API boundary, and errors are
        # reported in the response rather than raised to the caller.
        error_msg = str(e)
        processing_errors.append(error_msg)
        print(f"处理异常: {error_msg}")
    return DetectionResponse(
        hasTarget=1 if all_detections else 0,
        originalImgSize=original_size,
        targets=all_detections,
        processing_errors=processing_errors
    )