智能体加检测

This commit is contained in:
2025-12-02 17:16:26 +08:00
commit 61c3f26946
15 changed files with 1137 additions and 0 deletions

8
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

329
AI_Agent.py Normal file
View File

@ -0,0 +1,329 @@
import os
import re
import time
import subprocess
from pathlib import Path
from typing import List, Dict
from docx import Document
from shutil import which
import requests
# Path of the tender document (Word) that is read as input.
INPUT_WORD = r"C:\Users\YC\Desktop\1.docx"
# Path where the final generated bid document is written.
OUTPUT_WORD = r"C:\Users\YC\Desktop\投标文件-最终版.docx"
# Ollama model tag sent with every chat request.
OLLAMA_MODEL = "alibayram/Qwen3-30B-A3B-Instruct-2507:latest"
# Base URL of the Ollama server that call_llm posts to.
OLLAMA_BASE_URL = "http://192.168.110.5:11434"
# ==================== Ollama chat call (large context window, long output) ====================
def call_llm(messages: List[Dict], temperature=0.3, max_tokens=32768, num_ctx=131072):
    """Send one non-streaming chat request to Ollama and return the reply text.

    Args:
        messages: Chat messages in Ollama's ``[{"role": ..., "content": ...}]`` form.
        temperature: Sampling temperature, forwarded through ``options``.
        max_tokens: Upper bound on generated tokens (``num_predict``).
        num_ctx: Context window size requested from the server.

    Returns:
        The stripped reply text, or a fixed fallback marker string when all
        six attempts fail (this function never raises on request errors).
    """
    url = f"{OLLAMA_BASE_URL}/api/chat"
    payload = {
        "model": OLLAMA_MODEL,
        "messages": messages,
        "stream": False,
        "options": {
            # Sampling parameters are read from "options"; a top-level
            # "temperature" key is not part of the chat request schema.
            "temperature": temperature,
            "num_ctx": num_ctx,
            "num_predict": max_tokens,
            "num_gpu": 999,
            "top_p": 0.95,
            "top_k": 40,
            "repeat_penalty": 1.08,
            "mirostat": 2,
            "mirostat_tau": 5.0,
        },
    }
    headers = {"Content-Type": "application/json"}
    # Up to 6 attempts with a fixed pause between them (not exponential).
    for attempt in range(6):
        try:
            print(f" → 正在调用模型(第{attempt + 1}次尝试最大等待15分钟...")
            response = requests.post(
                url,
                json=payload,
                headers=headers,
                timeout=900,  # 15 minutes per attempt
            )
            response.raise_for_status()
            data = response.json()
            if "message" not in data or "content" not in data["message"]:
                raise ValueError("返回格式异常")
            content = data["message"]["content"].strip()
            print(f" √ 模型返回成功,本次生成约 {len(content) // 2}")
            return content
        except requests.exceptions.Timeout:
            print(f" ×{attempt + 1}次超时15分钟未返回10秒后重试...")
            time.sleep(10)
        except requests.exceptions.RequestException as e:
            print(f" ×{attempt + 1}次网络错误:{e}10秒后重试...")
            time.sleep(10)
        except Exception as e:
            # Malformed responses land here (e.g. the ValueError raised above).
            print(f" × 未知错误:{e}")
            time.sleep(5)
    print(" × 模型彻底失联,返回保底内容")
    return "【模型响应失败,已启用保底方案】"
# ==================== Word -> Markdown conversion ====================
def word_to_md(word_path: str) -> str:
    """Convert a Word document to Markdown and return the Markdown file path.

    Pandoc is used when it can be located (on PATH or in two common Windows
    install locations). If Pandoc is missing or exits with an error, the
    paragraph text is extracted with python-docx instead.

    Args:
        word_path: Path of the ``.docx`` file to convert.

    Returns:
        Path of the written file: ``<word_path without extension>_tender.md``.
    """
    md_path = os.path.splitext(word_path)[0] + "_tender.md"
    print(f"正在转换招标文件 → Markdown{os.path.basename(word_path)}")
    pandoc_cmd = which("pandoc") or which("pandoc.exe")
    if not pandoc_cmd:
        candidates = (
            os.path.expanduser(r"~\AppData\Local\Pandoc\pandoc.exe"),
            r"C:\Program Files\Pandoc\pandoc.exe",
        )
        pandoc_cmd = next((p for p in candidates if os.path.exists(p)), None)
    if pandoc_cmd:
        result = subprocess.run(
            [pandoc_cmd, word_path, "-t", "markdown", "-o", md_path,
             "--extract-media=media", "--wrap=none"],
            capture_output=True,
            text=True,
            # Decode explicitly: the locale codec may fail on Pandoc's output.
            encoding="utf-8",
            errors="replace",
        )
        if result.returncode == 0:
            print("Pandoc 转换成功!")
            return md_path
        # Pandoc was found but failed: say so and surface its stderr.
        print(f"Pandoc 转换失败(退出码 {result.returncode}),使用 python-docx 兜底:{result.stderr.strip()}")
    else:
        print("Pandoc 未找到,使用 python-docx 兜底...")
    doc = Document(word_path)
    text = "\n\n".join(p.text for p in doc.paragraphs if p.text.strip())
    Path(md_path).write_text(text, encoding="utf-8")
    print("纯文本提取完成!")
    return md_path
# ==================== Two-stage generation of the four-level outline ====================
def generate_full_outline(tender_md: str) -> str:
    """Generate a bid outline from the tender Markdown in two model passes.

    Pass 1 asks the model for a three-level skeleton based on the first
    60,000 characters of the tender text. Pass 2 takes the third-level
    titles found in that skeleton, eight at a time, and asks the model to
    expand each batch into fourth-level titles. The expansions are appended
    after the skeleton (they are not interleaved into it).

    Args:
        tender_md: Path of the tender document in Markdown form.

    Returns:
        The skeleton followed by all fourth-level expansions. The same text
        is written to ``output/四级目录.md``.
    """
    tender_text = Path(tender_md).read_text(encoding="utf-8")
    print(f"招标文件共 {len(tender_text)//2} 字,开始两阶段生成四级目录...")
    # Pass 1: three-level skeleton.
    prompt1 = f"""请仔细阅读以下招标文件核心内容只输出一个简洁但完整的三级目录一级用“一、”二级用“1、”三级用“1.1、”)。
不要四级标题,不要任何说明文字,不要页码。
招标文件摘录(最关键部分):
{tender_text[:60000]}
直接输出三级目录:"""
    print("第1步生成三级骨架10秒内必出...")
    outline_skeleton = call_llm([{"role": "user", "content": prompt1}],
                                temperature=0.01, max_tokens=10000)
    # Pass 2: expand the third-level titles in batches.
    print("第2步开始把每个三级标题展开成四级...")
    level3_pattern = re.compile(r'^\d+\.\d+[、 ]')
    level3_titles = [
        stripped
        for stripped in (line.strip() for line in outline_skeleton.split('\n'))
        if level3_pattern.match(stripped)
    ]
    outline_parts = [outline_skeleton + "\n"]
    for i in range(0, len(level3_titles), 8):
        batch = level3_titles[i:i+8]
        batch_text = "\n".join(batch)
        prompt2 = f"""你是一位招投标专家,请把下面这几个三级标题分别展开成 1018 个专业四级标题(格式必须是 1.1.1、1.1.2、……)。
只输出四级标题部分,不要重复三级标题本身。
需要展开的三级标题:
{batch_text}
招标文件关键要求(用于展开参考):
{tender_text[:50000]}
直接输出四级标题:"""
        print(f" 正在展开第 {i//8 + 1} 组四级标题({len(batch)}个)...")
        level4_text = call_llm([{"role": "user", "content": prompt2}],
                               temperature=0.2, max_tokens=20000)
        outline_parts.append("\n" + level4_text + "\n")
        time.sleep(2)
    full_outline = "".join(outline_parts)
    # Create the target directory so the function also works when called alone.
    outline_path = Path("output/四级目录.md")
    outline_path.parent.mkdir(parents=True, exist_ok=True)
    outline_path.write_text(full_outline, encoding="utf-8")
    print(f"超级四级目录生成成功!总计约 {len(full_outline)//2} 字(再也不怕超时了!)")
    return full_outline
# ==================== Body text generation, one batch of fourth-level titles at a time ====================
def batch_fill_content(outline: str, tender_text: str, batch_size: int = 6) -> str:
    """Generate body text for every fourth-level title found in the outline.

    Titles are lines that start with ``N.N.N、`` or ``N.N.N `` (after
    stripping). They are sent to the model in batches together with the
    first 60,000 characters of the tender text.

    Args:
        outline: Outline text to scan for fourth-level titles.
        tender_text: Full tender text; only a prefix is included per prompt.
        batch_size: Number of titles per model call (default 6).

    Returns:
        All generated parts joined with ``---`` separators. The same text
        is written to ``output/正文内容.md``.
    """
    level4_pattern = re.compile(r'^\d+\.\d+\.\d+[、 ]')
    level4_titles = [
        stripped
        for stripped in (line.strip() for line in outline.split('\n'))
        if level4_pattern.match(stripped)
    ]
    print(f"共检测到 {len(level4_titles)} 个四级标题,将分批生成详细内容...")
    all_content = ["# 正文内容开始"]
    # Ceiling division: an exact multiple must not report one batch too many.
    total_batches = (len(level4_titles) + batch_size - 1) // batch_size
    # Only a prefix of the tender text is sent, to stay within the context window.
    tender_excerpt = tender_text[:60000]
    for i in range(0, len(level4_titles), batch_size):
        batch = level4_titles[i:i + batch_size]
        titles_str = "\n".join(batch)
        prompt = f"""请为以下【{len(batch)}个四级标题】撰写极其详细、专业、可直接用于正式投标的正文内容。
要求每小节:
- 500—1000字内容充实、逻辑严密
- 至少包含 2 张以上专业 Markdown 表格(如进度表、资源配置表、检测项目表等)
- 使用【投标单位全称】【项目负责人】【联系电话】等占位符
- 语言正式、响应招标文件每一项要求
- 图文并茂(插入流程图、架构图说明文字)
当前批次标题:
{titles_str}
招标文件核心要求摘要(已精炼):
{tender_excerpt}
请按顺序为每个标题撰写完整内容,用 --- 分隔。"""
        print(f"正在生成第 {i // batch_size + 1}/{total_batches} 批({len(batch)}个小节)...")
        part = call_llm([{"role": "user", "content": prompt}], temperature=0.45, max_tokens=32000)
        all_content.append(part)
        time.sleep(2)  # short pause between consecutive model calls
    final_content = "\n\n---\n\n".join(all_content)
    # Create the target directory so the function also works when called alone.
    content_path = Path("output/正文内容.md")
    content_path.parent.mkdir(parents=True, exist_ok=True)
    content_path.write_text(final_content, encoding="utf-8")
    print(f"所有正文生成完成!总计约 {len(final_content) // 2}")
    return final_content
# ==================== Local padding with boilerplate sections ====================
def expand_to_50000_words(content: str) -> str:
    """Pad short content with fixed boilerplate sections.

    Content of at least 100,000 characters is returned unchanged. Anything
    shorter gets the same boilerplate block (after-sales service and
    project-reference tables) appended 15 times.

    Args:
        content: Generated body text.

    Returns:
        The original or padded text.
    """
    min_chars = 100000
    repeat_count = 15
    char_count = len(content)
    if not char_count < min_chars:
        return content
    print(f"当前 {char_count // 2} 字,正在补充至 5 万字+...")
    # Boilerplate sections appended verbatim.
    boilerplate = """
### 六、售后服务体系
#### 6.1 服务承诺
我单位承诺7×24小时响应2小时内到达现场终身免费维护核心系统...
#### 6.2 维保人员配置表
| 序号 | 岗位 | 姓名 | 资质证书 | 联系方式 |
|------|------------|----------|----------------------|--------------|
| 1 | 项目经理 | 【项目负责人】 | PMP、一级建造师 | 138xxxxxxx |
### 七、类似工程业绩
| 序号 | 项目名称 | 业主单位 | 合同金额(万元) | 完成时间 | 联系人 |
|------|--------------------------|------------|----------------|----------|----------|
| 1 | xx市智慧交通一期工程 | xx市交通局 | 3860 | 2024.12 | 张工 |
"""
    return "".join([content] + [boilerplate] * repeat_count)
# ==================== Refresh the tables of contents through Word automation ====================
def update_word_toc(docx_path: str):
    """Update every table of contents in a document via Word COM automation.

    This is best effort: any failure (pywin32 not installed, Word not
    available, document cannot be opened) is printed and swallowed so the
    caller's flow continues.

    Args:
        docx_path: Path of the ``.docx`` file to refresh in place.
    """
    try:
        import win32com.client as win32
        import pythoncom
        pythoncom.CoInitialize()
        try:
            word = win32.Dispatch('Word.Application')
            try:
                word.Visible = False
                doc = word.Documents.Open(os.path.abspath(docx_path))
                try:
                    for toc in doc.TablesOfContents:
                        toc.Update()
                    doc.Save()
                finally:
                    # Changes are saved explicitly above; close without
                    # saving so a failure cannot raise a save prompt in the
                    # invisible Word instance.
                    doc.Close(False)
            finally:
                # Always end the Word process, even when an update failed.
                word.Quit()
        finally:
            pythoncom.CoUninitialize()
    except Exception as e:
        print(f"Word目录自动更新失败可手动右键更新{e}")
# ==================== Main pipeline ====================
def main():
    """Run the pipeline: tender Word file -> outline -> body -> final Word file.

    Steps: convert the tender to Markdown, generate the outline, generate
    and pad the body, assemble the final Markdown, convert it to Word
    (Pandoc first, python-docx as fallback), refresh the table of contents
    and open the result.
    """
    print("启动本地 Qwen3-30B 投标文件生成器128K上下文版\n")
    os.makedirs("output", exist_ok=True)

    # 1. Convert the tender document.
    tender_md = word_to_md(INPUT_WORD)
    tender_text = Path(tender_md).read_text(encoding="utf-8")

    # 2. Generate the outline.
    outline = generate_full_outline(tender_md)

    # 3. Generate the body in batches, then pad it.
    content = expand_to_50000_words(batch_fill_content(outline, tender_text))

    # 4. Assemble the final Markdown.
    final_md = f"""# 【投标单位全称】
## {Path(INPUT_WORD).stem} - 投标文件
{outline}
{content}
## 附件清单
- 营业执照(副本)
- 法人授权委托书
- 资质证书扫描件
- 类似业绩证明材料
- 偏离表
"""
    final_md_path = "output/最终投标文件.md"
    Path(final_md_path).write_text(final_md, encoding="utf-8")
    print(f"\n最终 Markdown 生成成功!总计约 {len(final_md) // 2}")

    # 5. Convert to Word: Pandoc first, python-docx as fallback.
    print("正在转换为 Word 文档...")
    converted = False
    pandoc_cmd = which("pandoc") or which("pandoc.exe")
    if pandoc_cmd and os.path.exists(pandoc_cmd):
        command = [pandoc_cmd, final_md_path, "-o", OUTPUT_WORD]
        if os.path.exists("template.docx"):
            # Use the styling template when one sits in the working directory.
            command.append("--reference-doc=template.docx")
        completed = subprocess.run(command, capture_output=True)
        converted = completed.returncode == 0
    if not converted:
        print("Pandoc 失败,使用 python-docx 强制生成...")
        # Markdown heading prefix -> python-docx heading level.
        heading_levels = (("# ", 0), ("## ", 1), ("### ", 2), ("#### ", 3))
        doc = Document()
        for raw_line in final_md.split('\n'):
            text = raw_line.strip()
            if not text:
                continue
            for prefix, level in heading_levels:
                if text.startswith(prefix):
                    doc.add_heading(text[len(prefix):], level)
                    break
            else:
                doc.add_paragraph(text)
        doc.save(OUTPUT_WORD)

    update_word_toc(OUTPUT_WORD)
    print(f"\n大功告成!投标文件已生成:")
    print(f" {OUTPUT_WORD}")
    print(f" 总字数约:{len(final_md) // 2}")
    os.startfile(OUTPUT_WORD)


if __name__ == "__main__":
    main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

66
config.py Normal file
View File

@ -0,0 +1,66 @@
from typing import List
import torch
from pydantic import BaseModel
# Device selection: first CUDA device when available, otherwise CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Default detection parameters.
DEFAULT_CONF = 0.25
DEFAULT_IOU = 0.5
DEFAULT_MIN_SIZE = 8
DEFAULT_POS_THRESH = 5
# Per-model configuration, keyed by a display name. Each entry gives the
# weights path, the type names it serves, a type-name -> id mapping
# (presumably the model's class indices; verify against the weights) and
# detector parameters.
MODEL_CONFIGS = {
    "安全施工模型": {
        "model_path": "models/ppe_state_model/best.pt",
        "types": ["novest", "nohelmet"],
        "type_to_id": {"novest": 0, "nohelmet": 2},
        "params": {
            "enable_primary": True,
            "primary_conf": 0.55,
            "secondary_conf": 0.6,
            "final_conf": 0.65,
            "enable_multi_scale": True,
            "multi_scales": [0.75, 1.0, 1.25],
            "enable_secondary": True,
            "slice_size": 512,
            "overlap_ratio": 0.3,
            "weight_primary": 0.4,
            "weight_secondary": 0.6
        }
    },
    "烟雾火灾模型": {
        "model_path": "models/fire_smoke_model/best.pt",
        "types": ["fire", "smoke"],
        "type_to_id": {"fire": 0, "smoke": 1},
        "params": {
            "enable_primary": True,
            # NOTE(review): all three thresholds are 0.99 for this model —
            # confirm this is intentional.
            "primary_conf": 0.99,
            "secondary_conf": 0.99,
            "final_conf": 0.99,
            "enable_multi_scale": True,
            "multi_scales": [0.75, 1.0, 1.25],
            "enable_secondary": True,
            "slice_size": 512,
            "overlap_ratio": 0.3,
            "weight_primary": 0.4,
            "weight_secondary": 0.6
        }
    }
}
# Adaptive slicing rules for SAHI: (pixel-count threshold, (slice size, overlap ratio)).
SLICE_RULES = [
    (12_000_000, (384, 0.35)),
    (3_000_000, (512, 0.3)),
    (0, (640, 0.25))
]
class DetectionResponse(BaseModel):
    """Response body returned by the detection endpoints."""

    # 1 when at least one target was found, otherwise 0.
    hasTarget: int
    # Image size as [width, height].
    originalImgSize: List[int]
    # One dict per detected target.
    targets: List[dict]
    # Messages for errors that occurred while processing.
    processing_errors: List[str] = []

313
detect.py Normal file
View File

@ -0,0 +1,313 @@
import os
from collections import defaultdict
import cv2
import numpy as np
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from ultralytics import YOLO
from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF
class YOLODetector:
    """YOLO detector with whole-image inference and optional SAHI slicing.

    A detection run consists of a primary pass (optionally at several image
    scales), an optional sliced pass through SAHI, a per-class merge that
    suppresses overlapping boxes, and a final confidence / position filter.
    """

    def __init__(self, model_path, params, type_to_id):
        """Load the model and read thresholds from ``params``.

        Args:
            model_path: Path of the YOLO weights file.
            params: Detector parameters (thresholds, scales, weights, flags).
            type_to_id: Mapping from type name to class id, used to translate
                requested target types into class filters.
        """
        self.model = YOLO(model_path)
        self.model.to(DEVICE)
        self.class_names = self.model.names
        self.type_to_id = type_to_id
        self.params = params
        self.enable_primary = params.get("enable_primary", True)
        self.primary_conf = params.get("primary_conf", DEFAULT_CONF)      # primary pass threshold
        self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF)  # sliced pass threshold
        self.final_conf = params.get("final_conf", DEFAULT_CONF)          # final output threshold
        # Class filter for the current run; None means "all classes". Set
        # here so the detection stages can also be called before detect().
        self.target_ids = None
        # SAHI model, only built when the sliced pass is enabled.
        self.sahi_model = None
        if params["enable_secondary"]:
            self.sahi_model = AutoDetectionModel.from_pretrained(
                model_type='yolov8',
                model_path=model_path,
                confidence_threshold=self.secondary_conf,
                device=DEVICE
            )
        # Per-run counters.
        self.stats = defaultdict(int)

    def get_adaptive_slice(self, total_pixels):
        """Return ``(slice_size, overlap_ratio)`` for an image of the given pixel count.

        The first rule in ``SLICE_RULES`` whose threshold is exceeded wins;
        the values from ``params`` are used when no rule matches.
        """
        for pixel_thresh, (size, overlap) in SLICE_RULES:
            if total_pixels > pixel_thresh:
                return size, overlap
        return self.params["slice_size"], self.params["overlap_ratio"]

    @staticmethod
    def _read_image(img_path):
        """Read an image with OpenCV, raising ValueError when it cannot be decoded."""
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"无法读取图像: {img_path}")
        return img

    def _collect_boxes(self, results, scale=1.0):
        """Turn model results into detection dicts, mapping boxes back to the original scale."""
        detections = []
        for result in results:
            # Results without boxes contribute nothing.
            if result.boxes is None:
                continue
            for box in result.boxes:
                bbox = box.xyxy[0].tolist()
                if scale != 1.0:
                    bbox = [coord / scale for coord in bbox]
                detections.append({
                    "box": bbox,
                    "conf": box.conf[0].item(),
                    "class": box.cls[0].item(),
                    "class_name": self.class_names[int(box.cls[0])],
                    "source": "primary"
                })
        return detections

    def _run_model(self, source):
        """Run the YOLO model on a path or image array with the primary threshold."""
        return self.model(
            source,
            conf=self.primary_conf,
            device=DEVICE,
            classes=self.target_ids,
            verbose=False
        )

    def multi_scale_detect(self, img_path):
        """Run the primary pass at every configured scale.

        Resized images are handed to the model as in-memory arrays, so no
        temporary files are written. Boxes are mapped back to original
        image coordinates.
        """
        img = self._read_image(img_path)
        h, w = img.shape[:2]
        detections = []
        for scale in self.params["multi_scales"]:
            if scale == 1.0:
                source = img_path
            else:
                source = cv2.resize(img, (int(w * scale), int(h * scale)))
            detections.extend(self._collect_boxes(self._run_model(source), scale))
        return detections

    def primary_detect(self, img_path):
        """Run the primary (whole-image) pass; returns [] when it is disabled."""
        if not self.enable_primary:
            self.stats["primary"] = 0
            print(" 一级检测已禁用,跳过初级检测")
            return []
        if self.params["enable_multi_scale"]:
            detections = self.multi_scale_detect(img_path)
        else:
            detections = self._collect_boxes(self._run_model(img_path))
        self.stats["primary"] = len(detections)
        return detections

    def secondary_detect(self, img_path):
        """Run the SAHI sliced pass; returns [] when it is disabled.

        Boxes narrower or shorter than ``DEFAULT_MIN_SIZE`` pixels and boxes
        outside the current class filter are dropped.
        """
        if not self.params["enable_secondary"] or not self.sahi_model:
            return []
        img = self._read_image(img_path)
        h, w = img.shape[:2]
        slice_size, overlap = self.get_adaptive_slice(w * h)
        sliced_results = get_sliced_prediction(
            img_path,
            self.sahi_model,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap,
            overlap_width_ratio=overlap,
            verbose=0
        )
        detections = []
        for obj in sliced_results.object_prediction_list:
            if self.target_ids and obj.category.id not in self.target_ids:
                continue
            bbox = obj.bbox.to_xyxy()
            bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1]
            if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE:
                detections.append({
                    "box": bbox,
                    "conf": obj.score.value,
                    "class": obj.category.id,
                    "class_name": obj.category.name,
                    "source": "secondary"
                })
        self.stats["secondary"] = len(detections)
        return detections

    @staticmethod
    def calculate_iou(box1, box2):
        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes."""
        x11, y11, x21, y21 = box1
        x12, y12, x22, y22 = box2
        inter_x1 = max(x11, x12)
        inter_y1 = max(y11, y12)
        inter_x2 = min(x21, x22)
        inter_y2 = min(y21, y22)
        inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
        area1 = (x21 - x11) * (y21 - y11)
        area2 = (x22 - x12) * (y22 - y12)
        union_area = area1 + area2 - inter_area
        return inter_area / union_area if union_area > 0 else 0

    def merge_detections(self, primary_dets, secondary_dets):
        """Merge both passes and suppress overlapping boxes per class.

        Each detection gets a ``weighted_conf`` (confidence times the weight
        of its pass). Within a class, boxes are visited in descending
        weighted confidence and any later box overlapping a kept one by more
        than ``DEFAULT_IOU`` is dropped. Suppression also runs when only one
        pass produced detections, because the multi-scale primary pass can
        report the same object once per scale.

        Returns:
            The kept detections, as copies of the input dicts.
        """
        weighted = [
            dict(det, weighted_conf=det["conf"] * self.params["weight_primary"])
            for det in primary_dets
        ]
        weighted += [
            dict(det, weighted_conf=det["conf"] * self.params["weight_secondary"])
            for det in secondary_dets
        ]
        class_groups = defaultdict(list)
        for det in weighted:
            class_groups[det["class"]].append(det)
        merged = []
        for cls_dets in class_groups.values():
            cls_dets.sort(key=lambda d: d["weighted_conf"], reverse=True)
            suppressed = [False] * len(cls_dets)
            for i, keeper in enumerate(cls_dets):
                if suppressed[i]:
                    continue
                merged.append(keeper)
                for j in range(i + 1, len(cls_dets)):
                    if not suppressed[j] and self.calculate_iou(keeper["box"], cls_dets[j]["box"]) > DEFAULT_IOU:
                        suppressed[j] = True
        self.stats["merged"] = len(merged)
        return merged

    def post_process(self, detections):
        """Apply the final confidence threshold and drop near-identical positions.

        Two detections of the same class whose box centres are closer than
        ``DEFAULT_POS_THRESH`` pixels count as duplicates; the one with the
        higher confidence is kept.
        """
        filtered = [det for det in detections if det["conf"] >= self.final_conf]
        final_dets = []
        for curr_det in filtered:
            curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2
            curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2
            curr_cls = curr_det["class"]
            duplicate = False
            for idx, exist_det in enumerate(final_dets):
                if exist_det["class"] != curr_cls:
                    continue
                exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2
                exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2
                dist = np.sqrt((curr_cx - exist_cx) ** 2 + (curr_cy - exist_cy) ** 2)
                if dist < DEFAULT_POS_THRESH:
                    duplicate = True
                    if curr_det["conf"] > exist_det["conf"]:
                        final_dets[idx] = curr_det
                    break
            if not duplicate:
                final_dets.append(curr_det)
        self.stats["final"] = len(final_dets)
        return final_dets

    def format_results(self, detections):
        """Convert detection dicts into the API's target dicts (type, size, leftTopPoint, score)."""
        formatted = []
        for det in detections:
            x1, y1, x2, y2 = det["box"]
            formatted.append({
                "type": det["class_name"],
                "size": [int(round(x2 - x1)), int(round(y2 - y1))],
                "leftTopPoint": [int(round(x1)), int(round(y1))],
                "score": round(det["conf"], 4),
            })
        return formatted

    def get_detection_stats(self):
        """Return a plain-dict copy of the counters from the latest run."""
        return dict(self.stats)

    def detect(self, img_path, target_types=None):
        """Run the complete pipeline on one image.

        Args:
            img_path: Path of the image file.
            target_types: Optional type names to restrict detection to; names
                missing from ``type_to_id`` are ignored. ``None`` or empty
                means all classes.

        Returns:
            A list of formatted target dicts (see ``format_results``).
        """
        # Reset counters for this run.
        self.stats = defaultdict(int)
        # Translate the requested types into a class filter.
        if target_types:
            self.target_ids = [self.type_to_id[cls] for cls in target_types if cls in self.type_to_id]
        else:
            self.target_ids = None
        primary_dets = self.primary_detect(img_path)
        print(f" 初级检测后: {self.stats['primary']} 个目标")
        if self.params["enable_secondary"]:
            secondary_dets = self.secondary_detect(img_path)
            print(f" 次级检测后: {self.stats['secondary']} 个目标")
            merged_dets = self.merge_detections(primary_dets, secondary_dets)
            print(f" 融合去重后: {self.stats['merged']} 个目标")
        else:
            # Still suppress duplicates produced by the multi-scale pass.
            merged_dets = self.merge_detections(primary_dets, [])
            print(f" 次级检测未启用")
        processed_dets = self.post_process(merged_dets)
        print(f" 过滤低置信度后: {self.stats['final']} 个目标")
        print(" 最终检测目标详情:")
        for idx, det in enumerate(processed_dets, 1):
            print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}")
        return self.format_results(processed_dets)

125
main.py Normal file
View File

@ -0,0 +1,125 @@
import io
import os
import tempfile
from contextlib import asynccontextmanager
import requests
import uvicorn
from PIL import Image
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
from config import DetectionResponse
from process import detect_large_image_from_url
# Process-wide detection manager; assigned during application startup.
detector_manager = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the detection manager before the application starts serving.

    A failure during creation is printed and re-raised, so the application
    does not start without a manager.
    """
    global detector_manager
    try:
        # Imported inside the try so an import failure is reported the same
        # way as a construction failure.
        from manager import UnifiedDetectionManager
        created = UnifiedDetectionManager()
    except Exception as e:
        print(f"初始化失败:{str(e)}")
        raise
    else:
        detector_manager = created
        print("检测管理器初始化成功")
    yield
# Application instance; `lifespan` prepares the detection manager at startup.
app = FastAPI(lifespan=lifespan, title="目标检测API", version="1.0.0")
# Cross-origin request configuration.
# NOTE(review): wildcard origins are combined with allow_credentials=True —
# confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class DetectionRequest(BaseModel):
    """Request body for /detect_image."""

    # Comma-separated detection type names.
    type: str
    # URL of the image to download and analyse.
    url: HttpUrl


class DetectionProcessRequest(BaseModel):
    """Request body for /detect_process."""

    # URL of the image to download and analyse.
    url: HttpUrl
@app.post("/detect_image", response_model=DetectionResponse)
async def run_detection_image(request: DetectionRequest):
    """Download an image and run the requested detection types on it.

    Responds with HTTP 400 when no type is given or when the image cannot
    be downloaded, opened or stored. A failure raised by the detection call
    is reported in ``processing_errors`` rather than as an HTTP error.
    """
    # Normalise the comma-separated type list.
    requested_types = set()
    for token in request.type.split(','):
        cleaned = token.strip()
        if cleaned:
            requested_types.add(cleaned.lower())
    print(f"请求的检测类型: {requested_types}")
    if not requested_types:
        raise HTTPException(status_code=400, detail="未指定检测类型")

    # Download the image, read its size and store it in a temporary file.
    try:
        download = requests.get(str(request.url), timeout=15)
        download.raise_for_status()
        payload = download.content
        with Image.open(io.BytesIO(payload)) as picture:
            img_size = [picture.width, picture.height]
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as handle:
            handle.write(payload)
            temp_path = handle.name
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片处理失败:{str(e)}")

    # Run detection; the temporary file is removed afterwards in all cases.
    targets = []
    errors = []
    try:
        targets = detector_manager.detect(temp_path, ",".join(requested_types)) or []
    except Exception as e:
        errors.append(f"检测失败:{str(e)}")
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)

    has_target = 1 if targets else 0
    return {
        "hasTarget": has_target,
        "originalImgSize": img_size,
        "targets": targets,
        "processing_errors": errors,
    }
@app.post("/detect_process", response_model=DetectionResponse)
async def run_detection_process(request: DetectionProcessRequest):
    """Run the large-image detection pipeline on the image at the given URL."""
    image_url = str(request.url)
    outcome = detect_large_image_from_url(image_url)
    return outcome
@app.get("/supported_types")
async def get_supported_types():
    """Return the detection type names the loaded models support.

    The list is empty when no detection manager is available.
    """
    if not detector_manager:
        return {"supported_types": []}
    available = detector_manager.get_available_info()
    return {"supported_types": available["supported_types"]}
if __name__ == "__main__":
    # Command-line entry point: parse host/port/reload and start uvicorn.
    import argparse

    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--host", {"default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 8000}),
        ("--reload", {"action": "store_true"}),
    ):
        cli.add_argument(flag, **options)
    opts = cli.parse_args()
    uvicorn.run("main:app", host=opts.host, port=opts.port, reload=opts.reload)

116
manager.py Normal file
View File

@ -0,0 +1,116 @@
import os
from collections import defaultdict
from config import MODEL_CONFIGS
from detect import YOLODetector
class UnifiedDetectionManager:
    """Loads every configured model and routes detection requests to them.

    Type names are handled case-insensitively: they are lower-cased when the
    models are registered and when a request is parsed.
    """

    def __init__(self):
        self.detectors = {}      # model name -> detector instance
        self.type_to_model = {}  # lower-cased type name -> model name
        self.loaded_models = []  # names of successfully loaded models
        self.type_to_id = {}     # lower-cased type name -> class id
        self._load_models()

    def _load_models(self):
        """Create a detector for every entry in ``MODEL_CONFIGS``.

        Models whose weights file is missing, or whose setup raises, are
        skipped with a printed message; nothing is registered for them.

        Raises:
            ValueError: If ``MODEL_CONFIGS`` is empty.
        """
        if not MODEL_CONFIGS:
            raise ValueError("模型配置为空")
        for model_name, config in MODEL_CONFIGS.items():
            try:
                model_path = config["model_path"]
                if not os.path.exists(model_path):
                    print(f"跳过 {model_name}: 模型文件不存在 - {model_path}")
                    continue
                # Normalise the id map once, so lookups by lower-cased type
                # name work here and inside the detector.
                type_ids = {name.lower(): idx for name, idx in config["type_to_id"].items()}
                # Validate before loading the model, so a bad configuration
                # never leaves a half-registered detector behind.
                missing = [t for t in config["types"] if t.lower() not in type_ids]
                if missing:
                    raise KeyError(f"type_to_id has no entry for {missing}")
                detector = YOLODetector(
                    model_path=model_path,
                    params=config["params"],
                    type_to_id=type_ids
                )
                self.detectors[model_name] = detector
                self.loaded_models.append(model_name)
                for det_type in config["types"]:
                    det_type_lower = det_type.lower()
                    if det_type_lower in self.type_to_model:
                        print(f"警告: 类别 '{det_type}' 映射冲突")
                    self.type_to_model[det_type_lower] = model_name
                    self.type_to_id[det_type_lower] = type_ids[det_type_lower]
                print(f"加载成功: {model_name}")
            except Exception as e:
                print(f"加载失败 {model_name}: {str(e)}")
                continue
        print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
        print(f"支持类别: {list(self.type_to_model.keys())}")

    def parse_types(self, types_str):
        """Group the requested type names by the model that serves them.

        Args:
            types_str: Comma-separated type names; case and surrounding
                whitespace are ignored, duplicates are dropped.

        Returns:
            A mapping of model name to the list of requested types it serves.

        Raises:
            ValueError: If the input is empty or contains no known type.
        """
        if not types_str:
            raise ValueError("检测类型为空")
        # De-duplicate while keeping the request order stable.
        requested_types = list(dict.fromkeys(
            t.strip().lower() for t in types_str.split(',') if t.strip()
        ))
        model_type_map = defaultdict(list)
        for det_type in requested_types:
            model_name = self.type_to_model.get(det_type)
            if model_name is None:
                print(f"忽略未知类别: {det_type}")
                continue
            model_type_map[model_name].append(det_type)
        if not model_type_map:
            raise ValueError("无有效检测类别")
        return model_type_map

    def detect(self, img_path, detection_types):
        """Run every model needed for the requested types on one image.

        A failure inside a single model is printed and that model's results
        are omitted; the other models still run.

        Args:
            img_path: Path of the image file.
            detection_types: Comma-separated type names.

        Returns:
            The combined list of targets from all models.

        Raises:
            FileNotFoundError: If the image file does not exist.
            ValueError: If no valid detection type was requested.
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"图像不存在: {img_path}")
        model_type_map = self.parse_types(detection_types)
        all_results = []
        for model_name, target_types in model_type_map.items():
            if model_name not in self.detectors:
                continue
            print(f"检测: {model_name} -> {target_types}")
            try:
                results = self.detectors[model_name].detect(img_path, target_types)
                all_results.extend(results)
                stats = self.detectors[model_name].get_detection_stats()
                print(f" {model_name}详细统计: {stats}")
            except Exception as e:
                print(f"检测失败 {model_name}: {str(e)}")
        print(f"检测完成: 总共 {len(all_results)} 个结果")
        return all_results

    def get_available_info(self):
        """Return the loaded model names, supported types and the type -> model map."""
        return {
            "loaded_models": self.loaded_models,
            "supported_types": list(self.type_to_model.keys()),
            "type_to_model": self.type_to_model
        }

Binary file not shown.

Binary file not shown.

Binary file not shown.

180
process.py Normal file
View File

@ -0,0 +1,180 @@
import os
import tempfile
from typing import List
from urllib.parse import urlparse
import cv2
import requests
from pydantic import BaseModel
from tqdm import tqdm
from ultralytics import YOLO
# Response model returned by detect_large_image_from_url.
class DetectionResponse(BaseModel):
    """Result of a detection run on one image."""

    # 1 when at least one target was found, otherwise 0.
    hasTarget: int
    # Image size as [width, height].
    originalImgSize: List[int]
    # One dict per detected target.
    targets: List[dict]
    # Messages for errors that occurred while processing.
    processing_errors: List[str] = []
def download_large_file(url, chunk_size=1024 * 1024):
    """Stream a file from ``url`` into a temporary ``.png`` file.

    Args:
        url: Address of the file to download.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        Path of the temporary file. The caller is responsible for deleting it.

    Raises:
        Exception: When the download or the write fails. The partially
            written file is removed first and the original error is chained
            as the cause.
    """
    temp_file_path = None
    try:
        # The context manager releases the connection even on failure.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            file_size = int(response.headers.get('Content-Length', 0))
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as f, tqdm(
                # An unknown length is passed as None rather than a total of 0.
                total=file_size or None, unit='B', unit_scale=True,
                desc=f"下载 {os.path.basename(urlparse(url).path)}"
            ) as pbar:
                temp_file_path = f.name
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))
        return temp_file_path
    except Exception as e:
        error_msg = f"下载失败: {str(e)}"
        print(error_msg)
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                # Best-effort cleanup; the download error below matters more.
                pass
        raise Exception(error_msg) from e
def _slice_starts(length, slice_size, step):
    """Return the start offsets of the windows that cover ``length`` pixels.

    Windows advance by ``step``; a final window aligned to the end is added
    when the regular grid does not reach it. No offset is produced twice.
    """
    if length <= slice_size:
        return [0]
    last_start = length - slice_size
    starts = list(range(0, last_start + 1, step))
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


def slice_large_image(image_path, slice_size=1024, overlap=100):
    """Cut an image into overlapping tiles.

    Args:
        image_path: Path of the image file.
        slice_size: Edge length of a tile in pixels. Tiles are smaller only
            when the image itself is smaller than this.
        overlap: Number of pixels shared by neighbouring tiles.

    Returns:
        ``(slices, (height, width))`` where ``slices`` is a list of
        ``(x1, y1, tile)`` tuples in row-major order and ``x1``/``y1`` are
        the tile's top-left corner in the full image.

    Raises:
        ValueError: If ``slice_size`` is not positive, if ``overlap`` is not
            smaller than ``slice_size``, or if the image cannot be read.
    """
    if slice_size <= 0:
        raise ValueError(f"slice_size must be positive, got {slice_size}")
    if overlap >= slice_size:
        # A non-positive step would never advance across the image.
        raise ValueError(f"overlap ({overlap}) must be smaller than slice_size ({slice_size})")
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"无法读取图像: {image_path}")
    h, w = img.shape[:2]
    step = slice_size - overlap
    col_starts = _slice_starts(w, slice_size, step)
    slices = [
        (x1, y1, img[y1:min(y1 + slice_size, h), x1:min(x1 + slice_size, w)])
        for y1 in _slice_starts(h, slice_size, step)
        for x1 in col_starts
    ]
    return slices, (h, w)
def extract_detection_info(result, slice_offset_x, slice_offset_y):
    """Convert the OBB detections of one tile into target dicts.

    Width and height come from the ``xyxy`` box of each detection; the
    top-left corner is shifted by the tile offset so it refers to the full
    image.

    Args:
        result: One prediction result; its ``obb`` and ``names`` attributes are read.
        slice_offset_x: X position of the tile inside the full image.
        slice_offset_y: Y position of the tile inside the full image.

    Returns:
        A list of dicts with ``type``, ``size``, ``leftTopPoint`` and
        ``score``; empty when the result carries no OBB detections.
    """
    obb = result.obb
    if obb is None or len(obb) == 0:
        return []
    boxes = obb.xyxy.cpu().numpy()
    class_ids = obb.cls.cpu().numpy()
    scores = obb.conf.cpu().numpy()
    found = []
    for idx in range(len(obb)):
        left, top, right, bottom = boxes[idx]
        found.append({
            "type": result.names[int(class_ids[idx])],
            # x extent is the width, y extent is the height
            "size": [int(round(right - left)), int(round(bottom - top))],
            "leftTopPoint": [
                int(round(left + slice_offset_x)),
                int(round(top + slice_offset_y)),
            ],
            "score": round(float(scores[idx]), 4),
        })
    return found
def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100) -> DetectionResponse:
    """Download a large image, run tiled detection and return the result.

    The weights are read from ``models/solor_bracket_model/best.pt`` next to
    this file. Errors are never raised to the caller: they are collected in
    ``processing_errors`` of the returned response.

    Args:
        image_url: Address of the image to download.
        slice_size: Tile edge length passed to ``slice_large_image``.
        overlap: Tile overlap passed to ``slice_large_image``.

    Returns:
        A ``DetectionResponse`` with the targets of all tiles.
    """
    # The weights path is fixed relative to this file's directory.
    here = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(here, "models", "solor_bracket_model", "best.pt")
    errors = []
    found = []
    image_size = [0, 0]
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在:{model_path}")
        downloaded = download_large_file(image_url)
        try:
            tiles, (height, width) = slice_large_image(downloaded, slice_size, overlap)
            image_size = [width, height]
            tile_count = len(tiles)
            print(f"完成切片: 共 {tile_count} 个切片")
            model = YOLO(model_path)
            print("开始逐张预测切片...")
            for index, (offset_x, offset_y, tile) in enumerate(tiles, 1):
                print(f"预测第 {index}/{tile_count} 个切片")
                prediction = model(tile, conf=0.5, verbose=False)[0]
                tile_hits = extract_detection_info(prediction, offset_x, offset_y)
                found.extend(tile_hits)
                print(f" 本切片检测到 {len(tile_hits)} 个目标")
        finally:
            # Remove the downloaded file whether or not detection succeeded.
            if os.path.exists(downloaded):
                try:
                    os.remove(downloaded)
                    print("临时文件已删除")
                except Exception as e:
                    cleanup_msg = f"删除临时文件失败: {str(e)}"
                    print(cleanup_msg)
                    errors.append(cleanup_msg)
    except Exception as e:
        # Any failure is recorded in the response instead of being raised.
        failure = str(e)
        errors.append(failure)
        print(f"处理异常: {failure}")
    return DetectionResponse(
        hasTarget=int(bool(found)),
        originalImgSize=image_size,
        targets=found,
        processing_errors=errors,
    )