中煤科工 (CCTEG) Algorithm Collection
66 config.py Normal file
@@ -0,0 +1,66 @@
from typing import List

import torch
from pydantic import BaseModel

# Device configuration
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Default detection parameters
DEFAULT_CONF = 0.25
DEFAULT_IOU = 0.5
DEFAULT_MIN_SIZE = 8
DEFAULT_POS_THRESH = 5

MODEL_CONFIGS = {
    "安全施工模型": {  # safe-construction (PPE) model
        "model_path": "models/ppe_state_model/best.pt",
        "types": ["novest", "nohelmet"],
        "type_to_id": {"novest": 0, "nohelmet": 2},
        "params": {
            "enable_primary": True,
            "primary_conf": 0.55,
            "secondary_conf": 0.6,
            "final_conf": 0.65,
            "enable_multi_scale": True,
            "multi_scales": [0.75, 1.0, 1.25],
            "enable_secondary": True,
            "slice_size": 512,
            "overlap_ratio": 0.3,
            "weight_primary": 0.4,
            "weight_secondary": 0.6
        }
    },
    "烟雾火灾模型": {  # fire-and-smoke model
        "model_path": "models/fire_smoke_model/best.pt",
        "types": ["fire", "smoke"],
        "type_to_id": {"fire": 0, "smoke": 1},
        "params": {
            "enable_primary": True,
            "primary_conf": 0.99,
            "secondary_conf": 0.99,
            "final_conf": 0.99,
            "enable_multi_scale": True,
            "multi_scales": [0.75, 1.0, 1.25],
            "enable_secondary": True,
            "slice_size": 512,
            "overlap_ratio": 0.3,
            "weight_primary": 0.4,
            "weight_secondary": 0.6
        }
    }
}

# SAHI adaptive slicing rules: (min total pixels, (slice size, overlap ratio))
SLICE_RULES = [
    (12_000_000, (384, 0.35)),
    (3_000_000, (512, 0.3)),
    (0, (640, 0.25))
]


class DetectionResponse(BaseModel):
    hasTarget: int
    originalImgSize: List[int]
    targets: List[dict]
    processing_errors: List[str] = []
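A note on SLICE_RULES: detect.py scans the rules top-down and takes the first threshold the image's pixel count strictly exceeds, so larger images get smaller slices with more overlap. A minimal sketch of that lookup, with illustrative pixel counts (not from the repository):

# Sketch: how detect.py's get_adaptive_slice resolves SLICE_RULES.
from config import SLICE_RULES

def resolve_slice(total_pixels, fallback=(512, 0.3)):
    for pixel_thresh, (size, overlap) in SLICE_RULES:
        if total_pixels > pixel_thresh:
            return size, overlap
    return fallback

print(resolve_slice(5000 * 4000))   # 20M px   -> (384, 0.35)
print(resolve_slice(2560 * 1440))   # ~3.7M px -> (512, 0.3)
print(resolve_slice(640 * 480))     # small    -> (640, 0.25)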
313 detect.py Normal file
@@ -0,0 +1,313 @@
import os
from collections import defaultdict

import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from ultralytics import YOLO

from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF


class YOLODetector:
    def __init__(self, model_path, params, type_to_id):
        # Load the YOLO model
        self.model = YOLO(model_path)
        self.model.to(DEVICE)
        self.class_names = self.model.names
        self.type_to_id = type_to_id
        self.target_ids = None  # set per call in detect()

        self.params = params
        self.enable_primary = params.get("enable_primary", True)
        self.primary_conf = params.get("primary_conf", DEFAULT_CONF)  # primary-pass threshold
        self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF)  # secondary-pass threshold
        self.final_conf = params.get("final_conf", DEFAULT_CONF)  # final display threshold

        # SAHI model
        self.sahi_model = None
        if params["enable_secondary"]:
            self.sahi_model = AutoDetectionModel.from_pretrained(
                model_type='yolov8',
                model_path=model_path,
                confidence_threshold=self.secondary_conf,
                device=DEVICE
            )

        # Statistics
        self.stats = defaultdict(int)

    def get_adaptive_slice(self, total_pixels):
        """Pick slice parameters adaptively from SLICE_RULES."""
        for pixel_thresh, (size, overlap) in SLICE_RULES:
            if total_pixels > pixel_thresh:
                return size, overlap
        return self.params["slice_size"], self.params["overlap_ratio"]

    def multi_scale_detect(self, img_path):
        """Multi-scale detection (uses the model-specific primary threshold)."""
        detections = []
        img = cv2.imread(img_path)
        h, w = img.shape[:2]

        for scale in self.params["multi_scales"]:
            if scale == 1.0:
                # Detect at the original scale
                results = self.model(
                    img_path,
                    conf=self.primary_conf,  # model-specific primary threshold
                    device=DEVICE,
                    classes=self.target_ids,
                    verbose=False
                )
            else:
                # Detect at a resized scale
                nw, nh = int(w * scale), int(h * scale)
                scaled_img = cv2.resize(img, (nw, nh))
                temp_path = f"temp_scale_{scale}.jpg"
                cv2.imwrite(temp_path, scaled_img)

                results = self.model(
                    temp_path,
                    conf=self.primary_conf,  # model-specific primary threshold
                    device=DEVICE,
                    classes=self.target_ids,
                    verbose=False
                )
                os.remove(temp_path)

            # Parse results (core fix: guard against result.boxes being None)
            for result in results:
                # Skip results without boxes
                if result.boxes is None:
                    continue
                for box in result.boxes:
                    bbox = box.xyxy[0].tolist()
                    if scale != 1.0:
                        bbox = [coord / scale for coord in bbox]

                    detections.append({
                        "box": bbox,
                        "conf": box.conf[0].item(),
                        "class": box.cls[0].item(),
                        "class_name": self.class_names[int(box.cls[0])],
                        "source": "primary"
                    })

        return detections

    def primary_detect(self, img_path):
        """Primary detection (model-specific primary threshold); honors enable_primary."""
        # If the primary pass is disabled, return an empty list immediately
        if not self.enable_primary:
            self.stats["primary"] = 0
            print("  Primary detection disabled, skipping")
            return []

        if self.params["enable_multi_scale"]:
            detections = self.multi_scale_detect(img_path)
        else:
            results = self.model(
                img_path,
                conf=self.primary_conf,  # model-specific primary threshold
                device=DEVICE,
                classes=self.target_ids,
                verbose=False
            )
            # Parse results (core fix: guard against result.boxes being None)
            detections = []
            for result in results:
                # Skip results without boxes
                if result.boxes is None:
                    continue
                for box in result.boxes:
                    detections.append({
                        "box": box.xyxy[0].tolist(),
                        "conf": box.conf[0].item(),
                        "class": box.cls[0].item(),
                        "class_name": self.class_names[int(box.cls[0])],
                        "source": "primary"
                    })

        self.stats["primary"] = len(detections)
        return detections

    def secondary_detect(self, img_path):
        """SAHI sliced detection (initialized with the model-specific secondary threshold)."""
        if not self.params["enable_secondary"] or not self.sahi_model:
            return []

        img = cv2.imread(img_path)
        h, w = img.shape[:2]
        total_pixels = w * h
        slice_size, overlap = self.get_adaptive_slice(total_pixels)

        # SAHI sliced prediction
        sliced_results = get_sliced_prediction(
            img_path,
            self.sahi_model,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap,
            overlap_width_ratio=overlap,
            verbose=0
        )

        detections = []
        for obj in sliced_results.object_prediction_list:
            if self.target_ids and obj.category.id not in self.target_ids:
                continue

            bbox = obj.bbox.to_xyxy()
            bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1]

            if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE:
                detections.append({
                    "box": bbox,
                    "conf": obj.score.value,
                    "class": obj.category.id,
                    "class_name": obj.category.name,
                    "source": "secondary"
                })

        self.stats["secondary"] = len(detections)
        return detections

    @staticmethod
    def calculate_iou(box1, box2):
        """Compute the IoU of two xyxy boxes."""
        x11, y11, x21, y21 = box1
        x12, y12, x22, y22 = box2

        inter_x1 = max(x11, x12)
        inter_y1 = max(y11, y12)
        inter_x2 = min(x21, x22)
        inter_y2 = min(y21, y22)

        inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
        area1 = (x21 - x11) * (y21 - y11)
        area2 = (x22 - x12) * (y22 - y12)
        union_area = area1 + area2 - inter_area

        return inter_area / union_area if union_area > 0 else 0

    def merge_detections(self, primary_dets, secondary_dets):
        """Fuse primary and secondary detections."""
        if not primary_dets:
            return secondary_dets
        if not secondary_dets:
            return primary_dets

        # Weighted confidence
        all_dets = []
        for det in primary_dets:
            det["weighted_conf"] = det["conf"] * self.params["weight_primary"]
            all_dets.append(det)
        for det in secondary_dets:
            det["weighted_conf"] = det["conf"] * self.params["weight_secondary"]
            all_dets.append(det)

        # Group by class, then fuse
        class_groups = defaultdict(list)
        for det in all_dets:
            class_groups[det["class"]].append(det)

        merged = []
        for cls_id, cls_dets in class_groups.items():
            cls_dets.sort(key=lambda x: x["weighted_conf"], reverse=True)
            suppressed = [False] * len(cls_dets)

            for i in range(len(cls_dets)):
                if suppressed[i]:
                    continue
                merged.append(cls_dets[i])
                for j in range(i + 1, len(cls_dets)):
                    if not suppressed[j] and self.calculate_iou(cls_dets[i]["box"], cls_dets[j]["box"]) > DEFAULT_IOU:
                        suppressed[j] = True

        self.stats["merged"] = len(merged)
        return merged
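    # Note on merge_detections above: boxes from both passes are pooled, ranked
    # by conf * source weight, then greedily suppressed per class at
    # IoU > DEFAULT_IOU. With the configured weights (0.4 primary /
    # 0.6 secondary), a secondary box at raw conf 0.60 carries weighted
    # conf 0.36 and outranks a primary box at raw conf 0.80 (weighted 0.32);
    # post_process below still filters on the raw conf only.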
    def post_process(self, detections):
        """Post-process (uses the model-specific final threshold)."""
        # Confidence filter: model-specific final threshold
        filtered = [det for det in detections if det["conf"] >= self.final_conf]

        # Position-based deduplication
        final_dets = []
        for curr_det in filtered:
            curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2
            curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2
            curr_cls = curr_det["class"]
            duplicate = False

            for idx, exist_det in enumerate(final_dets):
                if exist_det["class"] != curr_cls:
                    continue

                exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2
                exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2
                dist = np.sqrt((curr_cx - exist_cx) ** 2 + (curr_cy - exist_cy) ** 2)

                if dist < DEFAULT_POS_THRESH:
                    duplicate = True
                    if curr_det["conf"] > exist_det["conf"]:
                        final_dets[idx] = curr_det
                    break

            if not duplicate:
                final_dets.append(curr_det)

        self.stats["final"] = len(final_dets)
        return final_dets

    def format_results(self, detections):
        """Format results for the API response."""
        formatted = []
        for det in detections:
            x1, y1, x2, y2 = det["box"]
            formatted.append({
                "type": det["class_name"],
                "size": [int(round(x2 - x1)), int(round(y2 - y1))],
                "leftTopPoint": [int(round(x1)), int(round(y1))],
                "score": round(det["conf"], 4),
            })
        return formatted

    def get_detection_stats(self):
        """Return the detection statistics."""
        return dict(self.stats)

    def detect(self, img_path, target_types=None):
        """Full detection pipeline."""
        # Reset statistics
        self.stats = defaultdict(int)

        # Resolve target class IDs
        if target_types:
            self.target_ids = [self.type_to_id[cls] for cls in target_types if cls in self.type_to_id]
        else:
            self.target_ids = None

        # Run detection
        primary_dets = self.primary_detect(img_path)
        print(f"  After primary detection: {self.stats['primary']} targets")

        if self.params["enable_secondary"]:
            secondary_dets = self.secondary_detect(img_path)
            print(f"  After secondary detection: {self.stats['secondary']} targets")
            merged_dets = self.merge_detections(primary_dets, secondary_dets)
            print(f"  After fusion and dedup: {self.stats['merged']} targets")
        else:
            merged_dets = primary_dets
            print("  Secondary detection not enabled")

        # Post-processing
        processed_dets = self.post_process(merged_dets)
        print(f"  After low-confidence filtering: {self.stats['final']} targets")

        print("  Final detections:")
        for idx, det in enumerate(processed_dets, 1):
            print(f"    Target {idx} - type: {det['class_name']}, confidence: {det['conf']:.4f}")

        return self.format_results(processed_dets)
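For reference, a minimal sketch of driving a single YOLODetector directly, assuming the configured weights exist on disk; in the service it is normally reached through UnifiedDetectionManager. The image path is a placeholder:

# Sketch: running one detector stand-alone. "test.jpg" is a placeholder path.
from config import MODEL_CONFIGS
from detect import YOLODetector

cfg = MODEL_CONFIGS["烟雾火灾模型"]  # the fire-and-smoke model
detector = YOLODetector(model_path=cfg["model_path"], params=cfg["params"], type_to_id=cfg["type_to_id"])

targets = detector.detect("test.jpg", target_types=["fire", "smoke"])
print(targets)                         # [{"type": ..., "size": [...], "leftTopPoint": [...], "score": ...}, ...]
print(detector.get_detection_stats())  # {"primary": ..., "secondary": ..., "merged": ..., "final": ...}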
126 main.py Normal file
@@ -0,0 +1,126 @@
import io
import os
import tempfile
from contextlib import asynccontextmanager

import requests
import uvicorn
from PIL import Image
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl

from config import DetectionResponse
from process import detect_large_image_from_url

# Global detection manager
detector_manager = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global detector_manager
    try:
        from manager import UnifiedDetectionManager
        detector_manager = UnifiedDetectionManager()
        print("Detection manager initialized successfully")
    except Exception as e:
        print(f"Initialization failed: {str(e)}")
        raise
    yield


app = FastAPI(lifespan=lifespan, title="Object Detection API", version="1.0.0")

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class DetectionRequest(BaseModel):
    type: str
    url: HttpUrl


class DetectionProcessRequest(BaseModel):
    url: HttpUrl


@app.post("/detect_image", response_model=DetectionResponse)
async def run_detection_image(request: DetectionRequest):
    # Parse the requested detection types
    requested_types = {t.strip().lower() for t in request.type.split(',') if t.strip()}
    print(f"Requested detection types: {requested_types}")
    if not requested_types:
        raise HTTPException(status_code=400, detail="No detection types specified")

    # Download the image
    try:
        response = requests.get(str(request.url), timeout=15)
        response.raise_for_status()

        # Read the image dimensions
        with Image.open(io.BytesIO(response.content)) as img:
            img_size = [img.width, img.height]

        # Write the payload to a temporary file
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
            temp_file.write(response.content)
            temp_path = temp_file.name

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Image processing failed: {str(e)}")

    # Run detection
    results = []
    errors = []

    try:
        detection_results = detector_manager.detect(temp_path, ",".join(requested_types))
        if detection_results:
            results = detection_results
    except Exception as e:
        errors.append(f"Detection failed: {str(e)}")
    finally:
        # Clean up the temporary file
        if os.path.exists(temp_path):
            os.remove(temp_path)

    return {
        "hasTarget": 1 if results else 0,
        "originalImgSize": img_size,
        "targets": results,
        "processing_errors": errors
    }


@app.post("/detect_process", response_model=DetectionResponse)
async def run_detection_process(request: DetectionProcessRequest):
    return detect_large_image_from_url(str(request.url))


@app.get("/supported_types")
async def get_supported_types():
    if detector_manager:
        info = detector_manager.get_available_info()
        return {
            "supported_types": info["supported_types"],
        }
    return {"supported_types": []}


if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--reload", action="store_true")
    args = parser.parse_args()
    uvicorn.run("main:app", host=args.host, port=args.port, reload=args.reload)
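A minimal client sketch against the two endpoints, assuming the service runs on the default host and port from the argparse settings; the image URL is a placeholder:

# Sketch: calling the API with requests. BASE and the image URL are placeholders.
import requests

BASE = "http://localhost:8000"

resp = requests.post(
    BASE + "/detect_image",
    json={"type": "fire,smoke", "url": "https://example.com/site.jpg"},
    timeout=60,
)
print(resp.json())  # {"hasTarget": ..., "originalImgSize": [...], "targets": [...], "processing_errors": [...]}

print(requests.get(BASE + "/supported_types").json())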
116 manager.py Normal file
@@ -0,0 +1,116 @@
import os
from collections import defaultdict

from config import MODEL_CONFIGS
from detect import YOLODetector


class UnifiedDetectionManager:
    """Unified detection manager."""

    def __init__(self):
        self.detectors = {}      # detector instances
        self.type_to_model = {}  # class-name -> model-name mapping
        self.loaded_models = []  # names of loaded models
        self.type_to_id = {}     # global class-ID mapping

        self._load_models()

    def _load_models(self):
        """Load all configured models."""
        if not MODEL_CONFIGS:
            raise ValueError("Model configuration is empty")

        for model_name, config in MODEL_CONFIGS.items():
            try:
                model_path = config["model_path"]
                if not os.path.exists(model_path):
                    print(f"Skipping {model_name}: model file not found - {model_path}")
                    continue

                # Create the detector (forwards the enable_primary setting)
                detector = YOLODetector(
                    model_path=model_path,
                    params=config["params"],
                    type_to_id=config["type_to_id"]
                )

                # Record state
                self.detectors[model_name] = detector
                self.loaded_models.append(model_name)

                # Build the mappings
                for det_type in config["types"]:
                    det_type_lower = det_type.lower()
                    if det_type_lower in self.type_to_model:
                        print(f"Warning: mapping conflict for class '{det_type}'")
                    self.type_to_model[det_type_lower] = model_name
                    self.type_to_id[det_type_lower] = config["type_to_id"][det_type_lower]

                print(f"Loaded: {model_name}")

            except Exception as e:
                print(f"Failed to load {model_name}: {str(e)}")
                continue

        print(f"Model loading finished: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
        print(f"Supported classes: {list(self.type_to_model.keys())}")

    def parse_types(self, types_str):
        """Parse the requested detection types."""
        if not types_str:
            raise ValueError("No detection types given")

        # Normalize the input
        requested_types = list(set(t.strip().lower() for t in types_str.split(',') if t.strip()))

        # Group by model
        model_type_map = defaultdict(list)
        for det_type in requested_types:
            if det_type in self.type_to_model:
                model_name = self.type_to_model[det_type]
                model_type_map[model_name].append(det_type)
            else:
                print(f"Ignoring unknown class: {det_type}")

        if not model_type_map:
            raise ValueError("No valid detection classes")

        return model_type_map

    def detect(self, img_path, detection_types):
        """Run detection across the relevant models."""
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # Parse the requested types
        model_type_map = self.parse_types(detection_types)

        # Run detection (each detector honors its own enable_primary setting)
        all_results = []
        for model_name, target_types in model_type_map.items():
            if model_name not in self.detectors:
                continue

            print(f"Detecting: {model_name} -> {target_types}")
            try:
                results = self.detectors[model_name].detect(img_path, target_types)
                all_results.extend(results)

                # Fetch detailed statistics
                stats = self.detectors[model_name].get_detection_stats()
                print(f"  {model_name} stats: {stats}")

            except Exception as e:
                print(f"Detection failed for {model_name}: {str(e)}")

        print(f"Detection finished: {len(all_results)} results in total")
        return all_results

    def get_available_info(self):
        """Return information about loaded models and supported classes."""
        return {
            "loaded_models": self.loaded_models,
            "supported_types": list(self.type_to_model.keys()),
            "type_to_model": self.type_to_model
        }
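For reference, a sketch of the manager's entry points as main.py uses them; the image path is a placeholder and at least one model file is assumed to be present:

# Sketch: the manager fans one request out to every model that owns a requested class.
from manager import UnifiedDetectionManager

manager = UnifiedDetectionManager()
print(manager.get_available_info()["supported_types"])  # e.g. ['novest', 'nohelmet', 'fire', 'smoke']

# "nohelmet" and "smoke" belong to different models; detect() routes each
# requested class to the detector that owns it and concatenates the results.
results = manager.detect("site.jpg", "nohelmet,smoke")  # placeholder image path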
180 process.py Normal file
@@ -0,0 +1,180 @@
import os
import tempfile
from typing import List
from urllib.parse import urlparse

import cv2
import requests
from pydantic import BaseModel
from tqdm import tqdm
from ultralytics import YOLO


# Response model
class DetectionResponse(BaseModel):
    hasTarget: int
    originalImgSize: List[int]
    targets: List[dict]
    processing_errors: List[str] = []


def download_large_file(url, chunk_size=1024 * 1024):
    """Stream a large file to a temporary file and return its path."""
    try:
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()

        file_size = int(response.headers.get('Content-Length', 0))

        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')  # handles PNG input
        temp_file_path = temp_file.name
        temp_file.close()

        with open(temp_file_path, 'wb') as f, tqdm(
                total=file_size, unit='B', unit_scale=True,
                desc=f"Downloading {os.path.basename(urlparse(url).path)}"
        ) as pbar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))

        return temp_file_path

    except Exception as e:
        error_msg = f"Download failed: {str(e)}"
        print(error_msg)
        if 'temp_file_path' in locals():
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        raise Exception(error_msg)


def slice_large_image(image_path, slice_size=1024, overlap=100):
    """Split a large image into slices; return the slices and the image size."""
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Cannot read image: {image_path}")

    h, w = img.shape[:2]
    step = slice_size - overlap
    num_rows = (h + step - 1) // step
    num_cols = (w + step - 1) // step

    slices = []
    for i in range(num_rows):
        for j in range(num_cols):
            y1 = i * step
            x1 = j * step
            y2 = min(y1 + slice_size, h)
            x2 = min(x1 + slice_size, w)

            # Shift edge slices inward so every slice is full-sized
            if y2 - y1 < slice_size:
                y1 = max(0, y2 - slice_size)
            if x2 - x1 < slice_size:
                x1 = max(0, x2 - slice_size)

            slice_img = img[y1:y2, x1:x2]
            slices.append((x1, y1, slice_img))

    return slices, (h, w)
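# Note on the edge handling above: edge slices are shifted inward rather than
# padded, so every slice is exactly slice_size px on a side (when the image
# allows) and edge neighbors overlap more than the nominal `overlap`.
# E.g. for an illustrative 2000x1500 image with the defaults (slice_size=1024,
# overlap=100, step=924): row origins are y=0 and y=476 (shifted up from 924);
# column origins are x=0, x=924, and x=976 (shifted left from 1848).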
def extract_detection_info(result, slice_offset_x, slice_offset_y):
    """Extract detection boxes from a YOLO OBB result (with corrected width/height)."""
    detections = []

    if result.obb is not None and len(result.obb) > 0:
        obb_data = result.obb
        obb_xyxy = obb_data.xyxy.cpu().numpy()
        classes = obb_data.cls.cpu().numpy()
        confidences = obb_data.conf.cpu().numpy()

        for i in range(len(obb_data)):
            x1_slice, y1_slice, x2_slice, y2_slice = obb_xyxy[i]
            # Actual width and height (x direction is width, y is height)
            width = x2_slice - x1_slice
            height = y2_slice - y1_slice

            # Convert to global coordinates
            x1_global = x1_slice + slice_offset_x
            y1_global = y1_slice + slice_offset_y

            cls_id = int(classes[i])
            confidence = float(confidences[i])
            class_name = result.names[cls_id]

            detection_info = {
                "type": class_name,
                "size": [int(round(width)), int(round(height))],
                "leftTopPoint": [int(round(x1_global)), int(round(y1_global))],
                "score": round(confidence, 4)
            }
            detections.append(detection_info)

    return detections


def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100) -> DetectionResponse:
    """
    Wrapped detection routine: process a large image from a URL and return a DetectionResponse.
    """
    # Build the fixed model_path relative to this file
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt")

    processing_errors = []
    all_detections = []
    original_size = [0, 0]

    try:
        # Verify the model file exists
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Download the image
        temp_file_path = download_large_file(image_url)

        try:
            # Slice the image
            slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap)
            original_size = [w, h]
            print(f"Slicing done: {len(slices_info)} slices")

            # Load the model and predict
            model = YOLO(model_path)
            print("Predicting slices one by one...")

            for i, (x1, y1, slice_img) in enumerate(slices_info, 1):
                print(f"Predicting slice {i}/{len(slices_info)}")
                result = model(slice_img, conf=0.5, verbose=False)[0]
                slice_detections = extract_detection_info(result, x1, y1)
                all_detections.extend(slice_detections)
                print(f"  {len(slice_detections)} targets in this slice")

        finally:
            # Make sure the temporary file is removed
            if os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                    print("Temporary file removed")
                except Exception as e:
                    error_msg = f"Failed to remove temporary file: {str(e)}"
                    print(error_msg)
                    processing_errors.append(error_msg)

    except Exception as e:
        # Catch and record all exceptions
        error_msg = str(e)
        processing_errors.append(error_msg)
        print(f"Processing error: {error_msg}")

    # Build and return the DetectionResponse
    return DetectionResponse(
        hasTarget=1 if len(all_detections) > 0 else 0,
        originalImgSize=original_size,
        targets=all_detections,
        processing_errors=processing_errors
    )
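A sketch of calling the large-image pipeline directly; /detect_process wraps exactly this call, and the URL here is a placeholder:

# Sketch: the standalone large-image pipeline. The URL is a placeholder.
from process import detect_large_image_from_url

response = detect_large_image_from_url(
    "https://example.com/orthophoto.png",
    slice_size=1024,
    overlap=100,
)
print(response.hasTarget, response.originalImgSize, len(response.targets))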