中煤科工算法合集
This commit is contained in:
180
process.py
Normal file
180
process.py
Normal file
@ -0,0 +1,180 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import cv2
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
# 定义返回值模型
|
||||
class DetectionResponse(BaseModel):
|
||||
hasTarget: int
|
||||
originalImgSize: List[int]
|
||||
targets: List[dict]
|
||||
processing_errors: List[str] = []
|
||||
|
||||
|
||||
def download_large_file(url, chunk_size=1024 * 1024):
|
||||
"""下载大型文件到临时文件、返回临时文件路径"""
|
||||
try:
|
||||
response = requests.get(url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
file_size = int(response.headers.get('Content-Length', 0))
|
||||
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') # 适配png格式
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
|
||||
with open(temp_file_path, 'wb') as f, tqdm(
|
||||
total=file_size, unit='B', unit_scale=True,
|
||||
desc=f"下载 {os.path.basename(urlparse(url).path)}"
|
||||
) as pbar:
|
||||
for chunk in response.iter_content(chunk_size=chunk_size):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
pbar.update(len(chunk))
|
||||
|
||||
return temp_file_path
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"下载失败: {str(e)}"
|
||||
print(error_msg)
|
||||
if 'temp_file_path' in locals():
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
except:
|
||||
pass
|
||||
raise Exception(error_msg)
|
||||
|
||||
|
||||
def slice_large_image(image_path, slice_size=1024, overlap=100):
|
||||
"""切分大图为切片、返回切片数据和位置信息"""
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
raise ValueError(f"无法读取图像: {image_path}")
|
||||
|
||||
h, w = img.shape[:2]
|
||||
step = slice_size - overlap
|
||||
num_rows = (h + step - 1) // step
|
||||
num_cols = (w + step - 1) // step
|
||||
|
||||
slices = []
|
||||
for i in range(num_rows):
|
||||
for j in range(num_cols):
|
||||
y1 = i * step
|
||||
x1 = j * step
|
||||
y2 = min(y1 + slice_size, h)
|
||||
x2 = min(x1 + slice_size, w)
|
||||
|
||||
if y2 - y1 < slice_size:
|
||||
y1 = max(0, y2 - slice_size)
|
||||
if x2 - x1 < slice_size:
|
||||
x1 = max(0, x2 - slice_size)
|
||||
|
||||
slice_img = img[y1:y2, x1:x2]
|
||||
slices.append((x1, y1, slice_img))
|
||||
|
||||
return slices, (h, w)
|
||||
|
||||
|
||||
def extract_detection_info(result, slice_offset_x, slice_offset_y):
|
||||
"""从YOLO OBB结果中提取检测框信息(修正宽高计算)"""
|
||||
detections = []
|
||||
|
||||
if result.obb is not None and len(result.obb) > 0:
|
||||
obb_data = result.obb
|
||||
obb_xyxy = obb_data.xyxy.cpu().numpy()
|
||||
classes = obb_data.cls.cpu().numpy()
|
||||
confidences = obb_data.conf.cpu().numpy()
|
||||
|
||||
for i in range(len(obb_data)):
|
||||
x1_slice, y1_slice, x2_slice, y2_slice = obb_xyxy[i]
|
||||
# 计算实际宽高(x方向为宽,y方向为高)
|
||||
width = x2_slice - x1_slice
|
||||
height = y2_slice - y1_slice
|
||||
|
||||
# 转换为全局坐标
|
||||
x1_global = x1_slice + slice_offset_x
|
||||
y1_global = y1_slice + slice_offset_y
|
||||
|
||||
cls_id = int(classes[i])
|
||||
confidence = float(confidences[i])
|
||||
class_name = result.names[cls_id]
|
||||
|
||||
detection_info = {
|
||||
"type": class_name,
|
||||
"size": [int(round(width)), int(round(height))],
|
||||
"leftTopPoint": [int(round(x1_global)), int(round(y1_global))],
|
||||
"score": round(confidence, 4)
|
||||
}
|
||||
detections.append(detection_info)
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100) -> DetectionResponse:
|
||||
"""
|
||||
封装后的检测方法:从图片URL处理大图、返回DetectionResponse对象
|
||||
"""
|
||||
# 动态拼接固定model_path(当前文件同级目录下)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt")
|
||||
|
||||
processing_errors = []
|
||||
all_detections = []
|
||||
original_size = [0, 0]
|
||||
|
||||
try:
|
||||
# 验证模型文件是否存在
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在:{model_path}")
|
||||
|
||||
# 下载图像
|
||||
temp_file_path = download_large_file(image_url)
|
||||
|
||||
try:
|
||||
# 切分图像
|
||||
slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap)
|
||||
original_size = [w, h]
|
||||
print(f"完成切片: 共 {len(slices_info)} 个切片")
|
||||
|
||||
# 加载模型并预测
|
||||
model = YOLO(model_path)
|
||||
print("开始逐张预测切片...")
|
||||
|
||||
for i, (x1, y1, slice_img) in enumerate(slices_info, 1):
|
||||
print(f"预测第 {i}/{len(slices_info)} 个切片")
|
||||
result = model(slice_img, conf=0.5, verbose=False)[0]
|
||||
slice_detections = extract_detection_info(result, x1, y1)
|
||||
all_detections.extend(slice_detections)
|
||||
print(f" 本切片检测到 {len(slice_detections)} 个目标")
|
||||
|
||||
finally:
|
||||
# 确保临时文件删除
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
print("临时文件已删除")
|
||||
except Exception as e:
|
||||
error_msg = f"删除临时文件失败: {str(e)}"
|
||||
print(error_msg)
|
||||
processing_errors.append(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
# 捕获所有异常并记录
|
||||
error_msg = str(e)
|
||||
processing_errors.append(error_msg)
|
||||
print(f"处理异常: {error_msg}")
|
||||
|
||||
# 构建并返回DetectionResponse对象
|
||||
return DetectionResponse(
|
||||
hasTarget=1 if len(all_detections) > 0 else 0,
|
||||
originalImgSize=original_size,
|
||||
targets=all_detections,
|
||||
processing_errors=processing_errors
|
||||
)
|
||||
Reference in New Issue
Block a user