中煤科工算法合集

This commit is contained in:
2025-12-02 16:43:56 +08:00
commit 0ac74b6892
5 changed files with 801 additions and 0 deletions

66
config.py Normal file
View File

@ -0,0 +1,66 @@
from typing import List
import torch
from pydantic import BaseModel
# Inference device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Default detection parameters.
DEFAULT_CONF = 0.25  # confidence threshold
DEFAULT_IOU = 0.5  # IoU threshold
DEFAULT_MIN_SIZE = 8  # minimum box side length
DEFAULT_POS_THRESH = 5  # position (centre distance) threshold
def _detection_params(primary_conf, secondary_conf, final_conf):
    """Build one model's ``params`` dict.

    The configured models share every setting except the three confidence
    thresholds, so only those are passed in. A fresh dict (and a fresh
    ``multi_scales`` list) is returned on every call.
    """
    return {
        "enable_primary": True,
        "primary_conf": primary_conf,
        "secondary_conf": secondary_conf,
        "final_conf": final_conf,
        "enable_multi_scale": True,
        "multi_scales": [0.75, 1.0, 1.25],
        "enable_secondary": True,
        "slice_size": 512,
        "overlap_ratio": 0.3,
        "weight_primary": 0.4,
        "weight_secondary": 0.6,
    }


# Per-model configuration, keyed by display name. Each entry gives the weights
# path, the detectable type names, the type-name -> class-id mapping and the
# detection parameters.
MODEL_CONFIGS = {
    "安全施工模型": {
        "model_path": "models/ppe_state_model/best.pt",
        "types": ["novest", "nohelmet"],
        "type_to_id": {"novest": 0, "nohelmet": 2},
        "params": _detection_params(
            primary_conf=0.55,
            secondary_conf=0.6,
            final_conf=0.65,
        ),
    },
    "烟雾火灾模型": {
        "model_path": "models/fire_smoke_model/best.pt",
        "types": ["fire", "smoke"],
        "type_to_id": {"fire": 0, "smoke": 1},
        "params": _detection_params(
            primary_conf=0.99,
            secondary_conf=0.99,
            final_conf=0.99,
        ),
    },
}
# Adaptive SAHI slicing rules, as (pixel_threshold, (slice_size, overlap_ratio)).
# Entries are ordered from the largest threshold down; the trailing 0 threshold
# matches any image with at least one pixel.
SLICE_RULES = [
    (12_000_000, (384, 0.35)),
    (3_000_000, (512, 0.3)),
    (0, (640, 0.25)),
]
class DetectionResponse(BaseModel):
    """Response body returned by the detection endpoints."""

    # 1 when at least one target was found, otherwise 0.
    hasTarget: int
    # Size of the analysed image as [width, height].
    originalImgSize: List[int]
    # One dict per detected target.
    targets: List[dict]
    # Human-readable messages for problems hit while processing.
    processing_errors: List[str] = []

313
detect.py Normal file
View File

@ -0,0 +1,313 @@
import os
from collections import defaultdict
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from ultralytics import YOLO
from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF
class YOLODetector:
    """Two-stage YOLO detector for a single model.

    The primary stage runs the model on the whole image (optionally at several
    scales); the optional secondary stage runs SAHI sliced inference. Both
    result sets are fused with class-wise NMS, then filtered by the final
    confidence threshold and de-duplicated by box-centre distance.

    Internally a detection is a dict with the keys ``box`` ([x1, y1, x2, y2]),
    ``conf``, ``class`` (int class id), ``class_name`` and ``source``
    ("primary" or "secondary").
    """

    def __init__(self, model_path, params, type_to_id):
        """Load the model(s) and cache the thresholds.

        Args:
            model_path: Path of the YOLO weights file.
            params: Detection parameter dict (see ``MODEL_CONFIGS[...]["params"]``).
            type_to_id: Mapping of type name -> model class id.
        """
        # Load the YOLO model.
        self.model = YOLO(model_path)
        self.model.to(DEVICE)
        self.class_names = self.model.names
        self.type_to_id = type_to_id
        self.params = params
        self.enable_primary = params.get("enable_primary", True)
        self.primary_conf = params.get("primary_conf", DEFAULT_CONF)  # primary-stage threshold
        self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF)  # secondary-stage threshold
        self.final_conf = params.get("final_conf", DEFAULT_CONF)  # threshold for reported results
        # Class-id filter; detect() sets it per call. Initialised here so the
        # stage methods can also be called on their own without AttributeError.
        self.target_ids = None
        # SAHI model, only needed for the secondary (sliced) stage.
        self.sahi_model = None
        if params["enable_secondary"]:
            self.sahi_model = AutoDetectionModel.from_pretrained(
                model_type='yolov8',
                model_path=model_path,
                confidence_threshold=self.secondary_conf,
                device=DEVICE
            )
        # Per-call counters.
        self.stats = defaultdict(int)

    @staticmethod
    def _read_image(img_path):
        """Read *img_path* with OpenCV, raising ValueError if it cannot be decoded."""
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"无法读取图像: {img_path}")
        return img

    def _run_model(self, source):
        """Run the YOLO model on *source* (a path or an image array) with the primary settings."""
        return self.model(
            source,
            conf=self.primary_conf,
            device=DEVICE,
            classes=self.target_ids,
            verbose=False
        )

    def _parse_boxes(self, results, scale=1.0):
        """Convert YOLO results into detection dicts.

        Box coordinates are divided by *scale* so that detections made on a
        resized image are expressed in original-image pixels.
        """
        detections = []
        for result in results:
            # A result may carry no boxes at all.
            if result.boxes is None:
                continue
            for box in result.boxes:
                bbox = box.xyxy[0].tolist()
                if scale != 1.0:
                    bbox = [coord / scale for coord in bbox]
                # Stored as int so ids match the ints produced by the SAHI stage.
                cls_id = int(box.cls[0].item())
                detections.append({
                    "box": bbox,
                    "conf": box.conf[0].item(),
                    "class": cls_id,
                    "class_name": self.class_names[cls_id],
                    "source": "primary"
                })
        return detections

    def get_adaptive_slice(self, total_pixels):
        """Return (slice_size, overlap_ratio) for an image of *total_pixels* pixels.

        The first rule in ``SLICE_RULES`` whose threshold is exceeded wins; the
        model's own ``slice_size`` / ``overlap_ratio`` are the fallback.
        """
        for pixel_thresh, (size, overlap) in SLICE_RULES:
            if total_pixels > pixel_thresh:
                return size, overlap
        return self.params["slice_size"], self.params["overlap_ratio"]

    def multi_scale_detect(self, img_path):
        """Run the primary model at every scale in ``params["multi_scales"]``.

        Returns all detections from all scales, in original-image coordinates.
        Duplicates across scales are not removed here.

        Raises:
            ValueError: If the image cannot be read.
        """
        img = self._read_image(img_path)
        h, w = img.shape[:2]
        detections = []
        for scale in self.params["multi_scales"]:
            if scale == 1.0:
                # Native scale: let the model read the file itself.
                source = img_path
            else:
                # The resized frame is handed to the model in memory. Writing it
                # to a fixed-name temp file could collide between processes,
                # leaked the file on errors and re-encoded the image lossily.
                nw, nh = max(1, int(w * scale)), max(1, int(h * scale))
                source = cv2.resize(img, (nw, nh))
            detections.extend(self._parse_boxes(self._run_model(source), scale))
        return detections

    def primary_detect(self, img_path):
        """Primary (whole-image) detection; returns [] when the stage is disabled."""
        if not self.enable_primary:
            self.stats["primary"] = 0
            print(" 一级检测已禁用,跳过初级检测")
            return []
        if self.params["enable_multi_scale"]:
            detections = self.multi_scale_detect(img_path)
        else:
            detections = self._parse_boxes(self._run_model(img_path))
        self.stats["primary"] = len(detections)
        return detections

    def secondary_detect(self, img_path):
        """SAHI sliced detection; returns [] when the stage is disabled.

        The SAHI model was created with the secondary confidence threshold.
        Boxes narrower or shorter than ``DEFAULT_MIN_SIZE`` are dropped.

        Raises:
            ValueError: If the image cannot be read.
        """
        if not self.params["enable_secondary"] or not self.sahi_model:
            return []
        img = self._read_image(img_path)
        h, w = img.shape[:2]
        slice_size, overlap = self.get_adaptive_slice(w * h)
        # SAHI sliced prediction.
        sliced_results = get_sliced_prediction(
            img_path,
            self.sahi_model,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap,
            overlap_width_ratio=overlap,
            verbose=0
        )
        detections = []
        for obj in sliced_results.object_prediction_list:
            if self.target_ids and obj.category.id not in self.target_ids:
                continue
            bbox = obj.bbox.to_xyxy()
            bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1]
            if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE:
                detections.append({
                    "box": bbox,
                    "conf": obj.score.value,
                    "class": int(obj.category.id),
                    "class_name": obj.category.name,
                    "source": "secondary"
                })
        self.stats["secondary"] = len(detections)
        return detections

    @staticmethod
    def calculate_iou(box1, box2):
        """Return the IoU of two [x1, y1, x2, y2] boxes (0 when the union is empty)."""
        x11, y11, x21, y21 = box1
        x12, y12, x22, y22 = box2
        inter_x1 = max(x11, x12)
        inter_y1 = max(y11, y12)
        inter_x2 = min(x21, x22)
        inter_y2 = min(y21, y22)
        inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
        area1 = (x21 - x11) * (y21 - y11)
        area2 = (x22 - x12) * (y22 - y12)
        union_area = area1 + area2 - inter_area
        return inter_area / union_area if union_area > 0 else 0

    def merge_detections(self, primary_dets, secondary_dets):
        """Fuse both stages with class-wise NMS on weighted confidence.

        Each detection gets a ``weighted_conf`` key (its confidence times the
        stage weight, written onto the input dict). Within a class, detections
        are visited from highest weighted confidence down, and any later box
        with IoU > ``DEFAULT_IOU`` against a kept box is suppressed.

        NMS also runs when one stage produced nothing, so duplicates from
        multi-scale detection are still removed and ``stats["merged"]`` always
        reflects the returned list.
        """
        weighted = []
        for dets, weight_key in (
            (primary_dets or [], "weight_primary"),
            (secondary_dets or [], "weight_secondary"),
        ):
            for det in dets:
                det["weighted_conf"] = det["conf"] * self.params[weight_key]
                weighted.append(det)
        # Group by class so boxes of different classes never suppress each other.
        class_groups = defaultdict(list)
        for det in weighted:
            class_groups[det["class"]].append(det)
        merged = []
        for cls_dets in class_groups.values():
            cls_dets.sort(key=lambda x: x["weighted_conf"], reverse=True)
            suppressed = [False] * len(cls_dets)
            for i, kept in enumerate(cls_dets):
                if suppressed[i]:
                    continue
                merged.append(kept)
                for j in range(i + 1, len(cls_dets)):
                    if not suppressed[j] and self.calculate_iou(kept["box"], cls_dets[j]["box"]) > DEFAULT_IOU:
                        suppressed[j] = True
        self.stats["merged"] = len(merged)
        return merged

    def post_process(self, detections):
        """Apply the final confidence threshold and remove positional duplicates.

        Two same-class detections whose box centres are closer than
        ``DEFAULT_POS_THRESH`` pixels count as one; the higher-confidence one
        is kept.
        """
        # Confidence filter with the model-specific final threshold.
        filtered = [det for det in detections if det["conf"] >= self.final_conf]
        # Positional de-duplication.
        final_dets = []
        for curr_det in filtered:
            curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2
            curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2
            curr_cls = curr_det["class"]
            duplicate = False
            for idx, exist_det in enumerate(final_dets):
                if exist_det["class"] != curr_cls:
                    continue
                exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2
                exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2
                dist = np.sqrt((curr_cx - exist_cx) ** 2 + (curr_cy - exist_cy) ** 2)
                if dist < DEFAULT_POS_THRESH:
                    duplicate = True
                    if curr_det["conf"] > exist_det["conf"]:
                        final_dets[idx] = curr_det
                    break
            if not duplicate:
                final_dets.append(curr_det)
        self.stats["final"] = len(final_dets)
        return final_dets

    def format_results(self, detections):
        """Convert internal detections to the API shape (type, size, leftTopPoint, score)."""
        formatted = []
        for det in detections:
            x1, y1, x2, y2 = det["box"]
            formatted.append({
                "type": det["class_name"],
                "size": [int(round(x2 - x1)), int(round(y2 - y1))],
                "leftTopPoint": [int(round(x1)), int(round(y1))],
                "score": round(det["conf"], 4),
            })
        return formatted

    def get_detection_stats(self):
        """Return a plain-dict copy of the counters from the last detect() call."""
        return dict(self.stats)

    def detect(self, img_path, target_types=None):
        """Run the full pipeline on one image.

        Args:
            img_path: Path of the image file.
            target_types: Optional iterable of type names to restrict detection
                to; names missing from ``type_to_id`` are ignored. ``None`` or
                empty means all classes.

        Returns:
            List of formatted result dicts (see ``format_results``).
        """
        # Reset the counters.
        self.stats = defaultdict(int)
        # Resolve the class filter.
        if target_types:
            self.target_ids = [self.type_to_id[cls] for cls in target_types if cls in self.type_to_id]
        else:
            self.target_ids = None
        # Run the stages.
        primary_dets = self.primary_detect(img_path)
        print(f" 初级检测后: {self.stats['primary']} 个目标")
        if self.params["enable_secondary"]:
            secondary_dets = self.secondary_detect(img_path)
            print(f" 次级检测后: {self.stats['secondary']} 个目标")
            merged_dets = self.merge_detections(primary_dets, secondary_dets)
            print(f" 融合去重后: {self.stats['merged']} 个目标")
        else:
            merged_dets = primary_dets
            print(" 次级检测未启用")
        # Post-processing.
        processed_dets = self.post_process(merged_dets)
        print(f" 过滤低置信度后: {self.stats['final']} 个目标")
        print(" 最终检测目标详情:")
        for idx, det in enumerate(processed_dets, 1):
            print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}")
        return self.format_results(processed_dets)

126
main.py Normal file
View File

@ -0,0 +1,126 @@
import io
import os
import tempfile
from contextlib import asynccontextmanager
import requests
import uvicorn
from PIL import Image
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
from config import DetectionResponse
from process import detect_large_image_from_url
# Process-wide detection manager; assigned once by `lifespan` at startup.
detector_manager = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the global detection manager before the app starts serving.

    Any error raised while building the manager is printed and re-raised.
    """
    global detector_manager
    try:
        # Imported inside the startup hook rather than at module import time.
        from manager import UnifiedDetectionManager
        manager_instance = UnifiedDetectionManager()
    except Exception as exc:
        print(f"初始化失败:{str(exc)}")
        raise
    detector_manager = manager_instance
    print("检测管理器初始化成功")
    yield
app = FastAPI(title="目标检测API", version="1.0.0", lifespan=lifespan)

# Cross-origin (CORS) settings: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# combination the FastAPI CORS docs advise against - confirm whether
# credentialed cross-origin requests are actually needed.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class DetectionRequest(BaseModel):
    """Request body for ``/detect_image``."""

    # Comma-separated detection type names.
    type: str
    # URL of the image to analyse.
    url: HttpUrl
class DetectionProcessRequest(BaseModel):
    """Request body for ``/detect_process``."""

    # URL of the image to analyse.
    url: HttpUrl
@app.post("/detect_image", response_model=DetectionResponse)
async def run_detection_image(request: DetectionRequest):
    """Download an image and run the requested detection types on it.

    ``request.type`` is a comma-separated list of type names (case-insensitive).
    The image is downloaded to a temporary file, which is always removed.

    Returns:
        A dict matching ``DetectionResponse``. Detection failures are reported
        in ``processing_errors`` rather than raised.

    Raises:
        HTTPException(400): No type was given, or the image could not be
            downloaded or decoded.
    """
    # Parse the requested detection types.
    requested_types = {t.strip().lower() for t in request.type.split(',') if t.strip()}
    print(f"请求的检测类型: {requested_types}")
    if not requested_types:
        raise HTTPException(status_code=400, detail="未指定检测类型")

    # NOTE(review): the URL is fetched as supplied by the caller - confirm that
    # requests to internal hosts are acceptable in this deployment.
    temp_path = None
    try:
        response = requests.get(str(request.url), timeout=15)
        response.raise_for_status()
        # Read the image size (also checks that the payload is an image).
        with Image.open(io.BytesIO(response.content)) as img:
            img_size = [img.width, img.height]
        # Persist to a temporary file because the detectors work on paths.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
            temp_path = temp_file.name
            temp_file.write(response.content)
    except Exception as e:
        # The file is created with delete=False, so remove it if the write failed.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
        raise HTTPException(status_code=400, detail=f"图片处理失败:{str(e)}") from e

    # Run the detection.
    results = []
    errors = []
    try:
        detection_results = detector_manager.detect(temp_path, ",".join(sorted(requested_types)))
        if detection_results:
            results = detection_results
    except Exception as e:
        errors.append(f"检测失败:{str(e)}")
    finally:
        # Clean up the temporary file.
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return {
        "hasTarget": 1 if results else 0,
        "originalImgSize": img_size,
        "targets": results,
        "processing_errors": errors
    }
@app.post("/detect_process", response_model=DetectionResponse)
async def run_detection_process(request: DetectionProcessRequest):
    """Run the sliced large-image detection on the image at ``request.url``."""
    image_url = str(request.url)
    return detect_large_image_from_url(image_url)
@app.get("/supported_types")
async def get_supported_types():
    """List the detection type names served by the loaded models.

    Returns an empty list when the detection manager is not available.
    """
    if not detector_manager:
        return {"supported_types": []}
    available = detector_manager.get_available_info()
    return {
        "supported_types": available["supported_types"],
    }
if __name__ == "__main__":
    # Script entry point: parse the CLI options and start uvicorn.
    import argparse

    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--host", {"default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 8000}),
        ("--reload", {"action": "store_true"}),
    ):
        cli.add_argument(flag, **options)
    opts = cli.parse_args()
    uvicorn.run("main:app", host=opts.host, port=opts.port, reload=opts.reload)

116
manager.py Normal file
View File

@ -0,0 +1,116 @@
import os
from collections import defaultdict
from config import MODEL_CONFIGS
from detect import YOLODetector
class UnifiedDetectionManager:
    """Loads every configured model and routes detection types to the right one."""

    def __init__(self):
        """Create the lookup tables and load all models from ``MODEL_CONFIGS``.

        Raises:
            ValueError: If ``MODEL_CONFIGS`` is empty.
        """
        self.detectors = {}  # model name -> detector instance
        self.type_to_model = {}  # lower-cased type name -> model name
        self.loaded_models = []  # names of successfully loaded models
        self.type_to_id = {}  # lower-cased type name -> class id
        self._load_models()

    def _load_models(self):
        """Load all configured models, skipping those that are missing or fail.

        A model is registered only after its detector and type mappings were
        built successfully, so a failure never leaves partial state behind.
        """
        if not MODEL_CONFIGS:
            raise ValueError("模型配置为空")
        for model_name, config in MODEL_CONFIGS.items():
            try:
                model_path = config["model_path"]
                if not os.path.exists(model_path):
                    print(f"跳过 {model_name}: 模型文件不存在 - {model_path}")
                    continue
                # Requests are matched in lower case, so normalise the
                # configured names once instead of assuming they are lower-case.
                type_ids = {name.lower(): cls_id for name, cls_id in config["type_to_id"].items()}
                # Raises KeyError (handled below) when a type has no class id,
                # before any time is spent loading weights.
                resolved = [(det_type, type_ids[det_type.lower()]) for det_type in config["types"]]
                detector = YOLODetector(
                    model_path=model_path,
                    params=config["params"],
                    type_to_id=type_ids
                )
            except Exception as e:
                print(f"加载失败 {model_name}: {str(e)}")
                continue
            # Record the loaded model.
            self.detectors[model_name] = detector
            self.loaded_models.append(model_name)
            # Build the type mappings; a later model overrides an earlier one.
            for det_type, cls_id in resolved:
                det_type_lower = det_type.lower()
                if det_type_lower in self.type_to_model:
                    print(f"警告: 类别 '{det_type}' 映射冲突")
                self.type_to_model[det_type_lower] = model_name
                self.type_to_id[det_type_lower] = cls_id
            print(f"加载成功: {model_name}")
        print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
        print(f"支持类别: {list(self.type_to_model.keys())}")

    def parse_types(self, types_str):
        """Group the requested types by the model that serves them.

        Args:
            types_str: Comma-separated type names; case and surrounding
                whitespace are ignored, duplicates are dropped.

        Returns:
            A mapping of model name -> list of type names, in request order.

        Raises:
            ValueError: If *types_str* is empty or names no known type.
        """
        if not types_str:
            raise ValueError("检测类型为空")
        # dict.fromkeys de-duplicates while keeping the request order, which
        # keeps the model order (and so the result order) deterministic.
        requested_types = list(dict.fromkeys(
            t.strip().lower() for t in types_str.split(',') if t.strip()
        ))
        # Group by model.
        model_type_map = defaultdict(list)
        for det_type in requested_types:
            if det_type in self.type_to_model:
                model_name = self.type_to_model[det_type]
                model_type_map[model_name].append(det_type)
            else:
                print(f"忽略未知类别: {det_type}")
        if not model_type_map:
            raise ValueError("无有效检测类别")
        return model_type_map

    def detect(self, img_path, detection_types):
        """Run every model needed for *detection_types* on one image.

        A failure inside one model is printed and does not stop the others.

        Args:
            img_path: Path of the image file.
            detection_types: Comma-separated type names.

        Returns:
            The combined list of result dicts from all models.

        Raises:
            FileNotFoundError: If *img_path* does not exist.
            ValueError: If no valid type was requested.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"图像不存在: {img_path}")
        # Resolve which model handles which types.
        model_type_map = self.parse_types(detection_types)
        all_results = []
        for model_name, target_types in model_type_map.items():
            if model_name not in self.detectors:
                continue
            print(f"检测: {model_name} -> {target_types}")
            try:
                results = self.detectors[model_name].detect(img_path, target_types)
                all_results.extend(results)
                # Per-model counters for the log.
                stats = self.detectors[model_name].get_detection_stats()
                print(f" {model_name}详细统计: {stats}")
            except Exception as e:
                print(f"检测失败 {model_name}: {str(e)}")
        print(f"检测完成: 总共 {len(all_results)} 个结果")
        return all_results

    def get_available_info(self):
        """Return the loaded models and supported types.

        The containers are copies, so callers cannot mutate the manager's state.
        """
        return {
            "loaded_models": list(self.loaded_models),
            "supported_types": list(self.type_to_model.keys()),
            "type_to_model": dict(self.type_to_model)
        }

180
process.py Normal file
View File

@ -0,0 +1,180 @@
import os
import tempfile
from typing import List
from urllib.parse import urlparse
import cv2
import requests
from pydantic import BaseModel
from tqdm import tqdm
from ultralytics import YOLO
# Response model returned by the large-image detection.
class DetectionResponse(BaseModel):
    """Result of one detection run."""

    # 1 when at least one target was found, otherwise 0.
    hasTarget: int
    # Size of the analysed image as [width, height].
    originalImgSize: List[int]
    # One dict per detected target.
    targets: List[dict]
    # Human-readable messages for problems hit while processing.
    processing_errors: List[str] = []
def download_large_file(url, chunk_size=1024 * 1024):
    """Stream *url* into a temporary ``.png`` file and return its path.

    The caller is responsible for deleting the returned file.

    Args:
        url: Address of the file to download.
        chunk_size: Number of bytes read per iteration.

    Returns:
        Path of the temporary file holding the download.

    Raises:
        Exception: On any failure; the message starts with "下载失败" and the
            original error is chained as the cause. A partially written file
            is removed first.
    """
    temp_file_path = None
    try:
        # The context manager releases the streamed connection in every case.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            file_size = int(response.headers.get('Content-Length', 0))
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                temp_file_path = temp_file.name
                with tqdm(
                    total=file_size, unit='B', unit_scale=True,
                    desc=f"下载 {os.path.basename(urlparse(url).path)}"
                ) as pbar:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if chunk:
                            temp_file.write(chunk)
                            pbar.update(len(chunk))
        return temp_file_path
    except Exception as e:
        error_msg = f"下载失败: {str(e)}"
        print(error_msg)
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                # Best-effort cleanup; the download error is what gets reported.
                pass
        raise Exception(error_msg) from e
def slice_large_image(image_path, slice_size=1024, overlap=100):
"""切分大图为切片、返回切片数据和位置信息"""
img = cv2.imread(image_path)
if img is None:
raise ValueError(f"无法读取图像: {image_path}")
h, w = img.shape[:2]
step = slice_size - overlap
num_rows = (h + step - 1) // step
num_cols = (w + step - 1) // step
slices = []
for i in range(num_rows):
for j in range(num_cols):
y1 = i * step
x1 = j * step
y2 = min(y1 + slice_size, h)
x2 = min(x1 + slice_size, w)
if y2 - y1 < slice_size:
y1 = max(0, y2 - slice_size)
if x2 - x1 < slice_size:
x1 = max(0, x2 - slice_size)
slice_img = img[y1:y2, x1:x2]
slices.append((x1, y1, slice_img))
return slices, (h, w)
def extract_detection_info(result, slice_offset_x, slice_offset_y):
    """Convert the OBB detections of one slice into result dicts.

    Args:
        result: A YOLO result for one slice; its ``obb`` attribute is read.
        slice_offset_x: X offset of the slice inside the full image.
        slice_offset_y: Y offset of the slice inside the full image.

    Returns:
        A list of dicts with ``type``, ``size`` ([width, height]),
        ``leftTopPoint`` (in full-image coordinates) and ``score``. Empty when
        the result has no OBB detections.
    """
    obb = result.obb
    if obb is None or len(obb) == 0:
        return []

    boxes = obb.xyxy.cpu().numpy()
    class_ids = obb.cls.cpu().numpy()
    scores = obb.conf.cpu().numpy()

    found = []
    for (left, top, right, bottom), raw_cls, raw_score in zip(boxes, class_ids, scores):
        # Width runs along x, height along y.
        box_w = right - left
        box_h = bottom - top
        # Shift the top-left corner from slice to full-image coordinates.
        global_left = left + slice_offset_x
        global_top = top + slice_offset_y
        label = result.names[int(raw_cls)]
        found.append({
            "type": label,
            "size": [int(round(box_w)), int(round(box_h))],
            "leftTopPoint": [int(round(global_left)), int(round(global_top))],
            "score": round(float(raw_score), 4)
        })
    return found
# Loaded models keyed by weights path, so repeated calls reuse one instance
# instead of reloading the weights from disk each time.
_MODEL_CACHE = {}


def _get_model(model_path):
    """Return the YOLO model for *model_path*, loading it on first use only."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = YOLO(model_path)
        _MODEL_CACHE[model_path] = model
    return model


def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100) -> DetectionResponse:
    """Download a large image, detect on overlapping slices and collect the results.

    The weights are read from ``models/solor_bracket_model/best.pt`` next to
    this file. Errors are never raised; they are returned in
    ``processing_errors``.

    Args:
        image_url: URL of the image to analyse.
        slice_size: Side length of a slice in pixels.
        overlap: Number of pixels shared by neighbouring slices.

    Returns:
        A ``DetectionResponse``; ``originalImgSize`` stays ``[0, 0]`` when the
        image could not be read.
    """
    # The model path is fixed relative to this file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt")
    processing_errors = []
    all_detections = []
    original_size = [0, 0]
    try:
        # Fail before downloading anything if the weights are missing.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在:{model_path}")
        # Download the image.
        temp_file_path = download_large_file(image_url)
        try:
            # Slice the image.
            slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap)
            original_size = [w, h]
            print(f"完成切片: 共 {len(slices_info)} 个切片")
            # Load (or reuse) the model and predict slice by slice.
            model = _get_model(model_path)
            print("开始逐张预测切片...")
            for i, (x1, y1, slice_img) in enumerate(slices_info, 1):
                print(f"预测第 {i}/{len(slices_info)} 个切片")
                result = model(slice_img, conf=0.5, verbose=False)[0]
                slice_detections = extract_detection_info(result, x1, y1)
                all_detections.extend(slice_detections)
                print(f" 本切片检测到 {len(slice_detections)} 个目标")
        finally:
            # Always remove the downloaded file.
            if os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                    print("临时文件已删除")
                except Exception as e:
                    error_msg = f"删除临时文件失败: {str(e)}"
                    print(error_msg)
                    processing_errors.append(error_msg)
    except Exception as e:
        # Record every failure instead of propagating it.
        error_msg = str(e)
        processing_errors.append(error_msg)
        print(f"处理异常: {error_msg}")
    # Build the response object.
    return DetectionResponse(
        hasTarget=1 if len(all_detections) > 0 else 0,
        originalImgSize=original_size,
        targets=all_detections,
        processing_errors=processing_errors
    )