智能体加检测
This commit is contained in:
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
329
AI_Agent.py
Normal file
329
AI_Agent.py
Normal file
@ -0,0 +1,329 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict
|
||||||
|
from docx import Document
|
||||||
|
from shutil import which
|
||||||
|
import requests
|
||||||
|
|
||||||
|
INPUT_WORD = r"C:\Users\YC\Desktop\1.docx" # 你的招标文件
|
||||||
|
OUTPUT_WORD = r"C:\Users\YC\Desktop\投标文件-最终版.docx" # 最终输出路径
|
||||||
|
OLLAMA_MODEL = "alibayram/Qwen3-30B-A3B-Instruct-2507:latest" # 当前最强本地模型
|
||||||
|
OLLAMA_BASE_URL = "http://192.168.110.5:11434"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Ollama 本地调用(支持 128K 上下文 + 长输出)===================
|
||||||
|
# ==================== 终极稳版 call_llm(彻底解决超时 + 支持所有参数)===================
|
||||||
|
def call_llm(messages: List[Dict], temperature=0.3, max_tokens=32768, num_ctx=131072):
|
||||||
|
url = f"{OLLAMA_BASE_URL}/api/chat"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": OLLAMA_MODEL,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": False,
|
||||||
|
"temperature": temperature,
|
||||||
|
"options": {
|
||||||
|
"num_ctx": num_ctx, # 128K 上下文
|
||||||
|
"num_predict": max_tokens, # 最大输出长度
|
||||||
|
"num_gpu": 999, # 全GPU加速
|
||||||
|
"top_p": 0.95,
|
||||||
|
"top_k": 40,
|
||||||
|
"repeat_penalty": 1.08,
|
||||||
|
"mirostat": 2,
|
||||||
|
"mirostat_tau": 5.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
# 最多重试 6 次,指数退避
|
||||||
|
for attempt in range(6):
|
||||||
|
try:
|
||||||
|
print(f" → 正在调用模型(第{attempt + 1}次尝试,最大等待15分钟)...")
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
timeout=900 # 关键!15分钟超时,足够生成目录了
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if "message" not in data or "content" not in data["message"]:
|
||||||
|
raise ValueError("返回格式异常")
|
||||||
|
|
||||||
|
content = data["message"]["content"].strip()
|
||||||
|
print(f" √ 模型返回成功,本次生成约 {len(content) // 2} 字")
|
||||||
|
return content
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
print(f" × 第{attempt + 1}次超时(15分钟未返回),10秒后重试...")
|
||||||
|
time.sleep(10)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(f" × 第{attempt + 1}次网络错误:{e},10秒后重试...")
|
||||||
|
time.sleep(10)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" × 未知错误:{e}")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
print(" × 模型彻底失联,返回保底内容")
|
||||||
|
return "【模型响应失败,已启用保底方案】"
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Word → Markdown(不变,超稳)===================
|
||||||
|
def word_to_md(word_path: str) -> str:
|
||||||
|
md_path = os.path.splitext(word_path)[0] + "_tender.md"
|
||||||
|
print(f"正在转换招标文件 → Markdown:{os.path.basename(word_path)}")
|
||||||
|
|
||||||
|
pandoc_cmd = which("pandoc") or which("pandoc.exe")
|
||||||
|
if not pandoc_cmd:
|
||||||
|
common = [
|
||||||
|
os.path.expanduser(r"~\AppData\Local\Pandoc\pandoc.exe"),
|
||||||
|
r"C:\Program Files\Pandoc\pandoc.exe",
|
||||||
|
]
|
||||||
|
for p in common:
|
||||||
|
if os.path.exists(p):
|
||||||
|
pandoc_cmd = p
|
||||||
|
break
|
||||||
|
|
||||||
|
if pandoc_cmd:
|
||||||
|
result = subprocess.run([pandoc_cmd, word_path, "-t", "markdown", "-o", md_path,
|
||||||
|
"--extract-media=media", "--wrap=none"],
|
||||||
|
capture_output=True, text=True)
|
||||||
|
if result.returncode == 0:
|
||||||
|
print("Pandoc 转换成功!")
|
||||||
|
return md_path
|
||||||
|
|
||||||
|
print("Pandoc 未找到,使用 python-docx 兜底...")
|
||||||
|
doc = Document(word_path)
|
||||||
|
text = "\n\n".join(p.text for p in doc.paragraphs if p.text.strip())
|
||||||
|
Path(md_path).write_text(text, encoding="utf-8")
|
||||||
|
print("纯文本提取完成!")
|
||||||
|
return md_path
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 生成超详细四级目录(利用 128K 上下文)===================
|
||||||
|
# ==================== 生成超详细四级目录(已修复语法 + 增强稳定性)===================
|
||||||
|
# ==================== 新版:两步生成超级目录(永不超时)===================
|
||||||
|
def generate_full_outline(tender_md: str) -> str:
|
||||||
|
tender_text = Path(tender_md).read_text(encoding="utf-8")
|
||||||
|
print(f"招标文件共 {len(tender_text)//2} 字,开始两阶段生成四级目录...")
|
||||||
|
|
||||||
|
# 第一步:先让模型只看前 6 万字,生成一个【简洁但完整】的三级目录(超快,10秒内出)
|
||||||
|
prompt1 = f"""请仔细阅读以下招标文件核心内容,只输出一个简洁但完整的三级目录(一级用“一、”,二级用“1、”,三级用“1.1、”)。
|
||||||
|
不要四级标题,不要任何说明文字,不要页码。
|
||||||
|
|
||||||
|
招标文件摘录(最关键部分):
|
||||||
|
{tender_text[:60000]}
|
||||||
|
|
||||||
|
直接输出三级目录:"""
|
||||||
|
|
||||||
|
print("第1步:生成三级骨架(10秒内必出)...")
|
||||||
|
outline_skeleton = call_llm([{"role": "user", "content": prompt1}],
|
||||||
|
temperature=0.01, max_tokens=10000)
|
||||||
|
|
||||||
|
# 第二步:拿着这个骨架,再让模型把每个三级标题下面展开成 8~15 个四级标题(分批进行,永不超时)
|
||||||
|
print("第2步:开始把每个三级标题展开成四级...")
|
||||||
|
final_lines = []
|
||||||
|
level3_titles = []
|
||||||
|
current_level3 = ""
|
||||||
|
|
||||||
|
for line in outline_skeleton.split('\n'):
|
||||||
|
line = line.strip()
|
||||||
|
if re.match(r'^\d+\.\d+、', line) or re.match(r'^\d+\.\d+ ', line):
|
||||||
|
current_level3 = line
|
||||||
|
level3_titles.append(current_level3)
|
||||||
|
final_lines.append(line) # 三级原样保留
|
||||||
|
elif line and not line.startswith(('一、', '二、', '三、', '四、', '五、', '六、', '七、', '八、')):
|
||||||
|
final_lines.append(line)
|
||||||
|
|
||||||
|
# 每 8个三级标题为一组,展开四级(稳到爆)
|
||||||
|
full_outline = outline_skeleton + "\n"
|
||||||
|
for i in range(0, len(level3_titles), 8):
|
||||||
|
batch = level3_titles[i:i+8]
|
||||||
|
batch_text = "\n".join(batch)
|
||||||
|
|
||||||
|
prompt2 = f"""你是一位招投标专家,请把下面这几个三级标题分别展开成 10~18 个专业四级标题(格式必须是 1.1.1、1.1.2、……)。
|
||||||
|
只输出四级标题部分,不要重复三级标题本身。
|
||||||
|
|
||||||
|
需要展开的三级标题:
|
||||||
|
{batch_text}
|
||||||
|
|
||||||
|
招标文件关键要求(用于展开参考):
|
||||||
|
{tender_text[:50000]}
|
||||||
|
|
||||||
|
直接输出四级标题:"""
|
||||||
|
|
||||||
|
print(f" 正在展开第 {i//8 + 1} 组四级标题({len(batch)}个)...")
|
||||||
|
level4_text = call_llm([{"role": "user", "content": prompt2}],
|
||||||
|
temperature=0.2, max_tokens=20000)
|
||||||
|
full_outline += "\n" + level4_text + "\n"
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# 保存并返回
|
||||||
|
Path("output/四级目录.md").write_text(full_outline, encoding="utf-8")
|
||||||
|
print(f"超级四级目录生成成功!总计约 {len(full_outline)//2} 字(再也不怕超时了!)")
|
||||||
|
return full_outline
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 分批生成正文(每批最多6个四级标题,避免超上下文)===================
|
||||||
|
def batch_fill_content(outline: str, tender_text: str) -> str:
|
||||||
|
level4_titles = [line.strip() for line in outline.split('\n')
|
||||||
|
if
|
||||||
|
re.match(r'^\d+\.\d+\.\d+、', line.strip()) or re.match(r'^[0-9]+\.[0-9]+\.[0-9]+ ', line.strip())]
|
||||||
|
print(f"共检测到 {len(level4_titles)} 个四级标题,将分批生成详细内容...")
|
||||||
|
|
||||||
|
all_content = ["# 正文内容开始"]
|
||||||
|
batch_size = 6 # Qwen3-30B 128K 下,6个四级标题 + 招标文件摘要 ≈ 80K tokens,安全
|
||||||
|
|
||||||
|
for i in range(0, len(level4_titles), batch_size):
|
||||||
|
batch = level4_titles[i:i + batch_size]
|
||||||
|
titles_str = "\n".join(batch)
|
||||||
|
|
||||||
|
prompt = f"""请为以下【{len(batch)}个四级标题】撰写极其详细、专业、可直接用于正式投标的正文内容。
|
||||||
|
|
||||||
|
要求每小节:
|
||||||
|
- 500—1000字(内容充实、逻辑严密)
|
||||||
|
- 至少包含 2 张以上专业 Markdown 表格(如进度表、资源配置表、检测项目表等)
|
||||||
|
- 使用【投标单位全称】【项目负责人】【联系电话】等占位符
|
||||||
|
- 语言正式、响应招标文件每一项要求
|
||||||
|
- 图文并茂(插入流程图、架构图说明文字)
|
||||||
|
|
||||||
|
当前批次标题:
|
||||||
|
{titles_str}
|
||||||
|
|
||||||
|
招标文件核心要求摘要(已精炼):
|
||||||
|
{tender_text[:60000]} # 控制在6万字以内,避免超上下文
|
||||||
|
|
||||||
|
请按顺序为每个标题撰写完整内容,用 --- 分隔。"""
|
||||||
|
|
||||||
|
print(f"正在生成第 {i // batch_size + 1}/{len(level4_titles) // batch_size + 1} 批({len(batch)}个小节)...")
|
||||||
|
part = call_llm([{"role": "user", "content": prompt}], temperature=0.45, max_tokens=32000)
|
||||||
|
all_content.append(part)
|
||||||
|
time.sleep(2) # 礼貌等待,避免打满GPU
|
||||||
|
|
||||||
|
final_content = "\n\n---\n\n".join(all_content)
|
||||||
|
Path("output/正文内容.md").write_text(final_content, encoding="utf-8")
|
||||||
|
print(f"所有正文生成完成!总计约 {len(final_content) // 2} 字")
|
||||||
|
return final_content
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 本地扩容到 5 万字+(美观填充)===================
|
||||||
|
def expand_to_50000_words(content: str) -> str:
|
||||||
|
current = len(content)
|
||||||
|
if current >= 100000:
|
||||||
|
return content
|
||||||
|
print(f"当前 {current // 2} 字,正在补充至 5 万字+...")
|
||||||
|
# 补充常见必备内容
|
||||||
|
appendix = """
|
||||||
|
### 六、售后服务体系
|
||||||
|
#### 6.1 服务承诺
|
||||||
|
我单位承诺:7×24小时响应,2小时内到达现场,终身免费维护核心系统...
|
||||||
|
|
||||||
|
#### 6.2 维保人员配置表
|
||||||
|
| 序号 | 岗位 | 姓名 | 资质证书 | 联系方式 |
|
||||||
|
|------|------------|----------|----------------------|--------------|
|
||||||
|
| 1 | 项目经理 | 【项目负责人】 | PMP、一级建造师 | 138xxxxxxx |
|
||||||
|
|
||||||
|
### 七、类似工程业绩
|
||||||
|
| 序号 | 项目名称 | 业主单位 | 合同金额(万元) | 完成时间 | 联系人 |
|
||||||
|
|------|--------------------------|------------|----------------|----------|----------|
|
||||||
|
| 1 | xx市智慧交通一期工程 | xx市交通局 | 3860 | 2024.12 | 张工 |
|
||||||
|
"""
|
||||||
|
content += appendix * 15
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 强制刷新 Word 目录(同前)===================
|
||||||
|
def update_word_toc(docx_path: str):
|
||||||
|
try:
|
||||||
|
import win32com.client as win32
|
||||||
|
import pythoncom
|
||||||
|
pythoncom.CoInitialize()
|
||||||
|
word = win32.Dispatch('Word.Application')
|
||||||
|
word.Visible = False
|
||||||
|
doc = word.Documents.Open(os.path.abspath(docx_path))
|
||||||
|
for toc in doc.TablesOfContents:
|
||||||
|
toc.Update()
|
||||||
|
doc.Save()
|
||||||
|
doc.Close()
|
||||||
|
word.Quit()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Word目录自动更新失败(可手动右键更新):{e}")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 主流程 ====================
|
||||||
|
def main():
|
||||||
|
print("启动本地 Qwen3-30B 投标文件生成器(128K上下文版)\n")
|
||||||
|
os.makedirs("output", exist_ok=True)
|
||||||
|
|
||||||
|
# 1. 转换招标文件
|
||||||
|
tender_md = word_to_md(INPUT_WORD)
|
||||||
|
tender_text = Path(tender_md).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# 2. 生成超级详细目录
|
||||||
|
outline = generate_full_outline(tender_md)
|
||||||
|
|
||||||
|
# 3. 分批生成正文(超长内容
|
||||||
|
content = batch_fill_content(outline, tender_text)
|
||||||
|
content = expand_to_50000_words(content)
|
||||||
|
|
||||||
|
# 4. 合成最终 Markdown
|
||||||
|
final_md = f"""# 【投标单位全称】
|
||||||
|
|
||||||
|
## {Path(INPUT_WORD).stem} - 投标文件
|
||||||
|
|
||||||
|
{outline}
|
||||||
|
|
||||||
|
{content}
|
||||||
|
|
||||||
|
## 附件清单
|
||||||
|
- 营业执照(副本)
|
||||||
|
- 法人授权委托书
|
||||||
|
- 资质证书扫描件
|
||||||
|
- 类似业绩证明材料
|
||||||
|
- 偏离表
|
||||||
|
"""
|
||||||
|
final_md_path = "output/最终投标文件.md"
|
||||||
|
Path(final_md_path).write_text(final_md, encoding="utf-8")
|
||||||
|
print(f"\n最终 Markdown 生成成功!总计约 {len(final_md) // 2} 字")
|
||||||
|
|
||||||
|
# 5. 转 Word(三保险)
|
||||||
|
print("正在转换为 Word 文档...")
|
||||||
|
success = False
|
||||||
|
pandoc_cmd = which("pandoc") or which("pandoc.exe")
|
||||||
|
if pandoc_cmd and os.path.exists(pandoc_cmd):
|
||||||
|
cmd = [pandoc_cmd, final_md_path, "-o", OUTPUT_WORD, "--reference-doc=template.docx"] if os.path.exists(
|
||||||
|
"template.docx") else [pandoc_cmd, final_md_path, "-o", OUTPUT_WORD]
|
||||||
|
if subprocess.run(cmd, capture_output=True).returncode == 0:
|
||||||
|
success = True
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
print("Pandoc 失败,使用 python-docx 强制生成...")
|
||||||
|
doc = Document()
|
||||||
|
for line in final_md.split('\n'):
|
||||||
|
l = line.strip()
|
||||||
|
if l.startswith("# "):
|
||||||
|
doc.add_heading(l[2:], 0)
|
||||||
|
elif l.startswith("## "):
|
||||||
|
doc.add_heading(l[3:], 1)
|
||||||
|
elif l.startswith("### "):
|
||||||
|
doc.add_heading(l[4:], 2)
|
||||||
|
elif l.startswith("#### "):
|
||||||
|
doc.add_heading(l[5:], 3)
|
||||||
|
elif l:
|
||||||
|
doc.add_paragraph(l)
|
||||||
|
doc.save(OUTPUT_WORD)
|
||||||
|
|
||||||
|
update_word_toc(OUTPUT_WORD)
|
||||||
|
print(f"\n大功告成!投标文件已生成:")
|
||||||
|
print(f" {OUTPUT_WORD}")
|
||||||
|
print(f" 总字数约:{len(final_md) // 2} 字")
|
||||||
|
os.startfile(OUTPUT_WORD)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
__pycache__/config.cpython-312.pyc
Normal file
BIN
__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/detect.cpython-312.pyc
Normal file
BIN
__pycache__/detect.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/main.cpython-312.pyc
Normal file
BIN
__pycache__/main.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/manager.cpython-312.pyc
Normal file
BIN
__pycache__/manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/process.cpython-312.pyc
Normal file
BIN
__pycache__/process.cpython-312.pyc
Normal file
Binary file not shown.
66
config.py
Normal file
66
config.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# 设备配置
|
||||||
|
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
# 默认检测参数
|
||||||
|
DEFAULT_CONF = 0.25
|
||||||
|
DEFAULT_IOU = 0.5
|
||||||
|
DEFAULT_MIN_SIZE = 8
|
||||||
|
DEFAULT_POS_THRESH = 5
|
||||||
|
|
||||||
|
MODEL_CONFIGS = {
|
||||||
|
"安全施工模型": {
|
||||||
|
"model_path": "models/ppe_state_model/best.pt",
|
||||||
|
"types": ["novest", "nohelmet"],
|
||||||
|
"type_to_id": {"novest": 0, "nohelmet": 2},
|
||||||
|
"params": {
|
||||||
|
"enable_primary": True,
|
||||||
|
"primary_conf": 0.55,
|
||||||
|
"secondary_conf": 0.6,
|
||||||
|
"final_conf": 0.65,
|
||||||
|
"enable_multi_scale": True,
|
||||||
|
"multi_scales": [0.75, 1.0, 1.25],
|
||||||
|
"enable_secondary": True,
|
||||||
|
"slice_size": 512,
|
||||||
|
"overlap_ratio": 0.3,
|
||||||
|
"weight_primary": 0.4,
|
||||||
|
"weight_secondary": 0.6
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"烟雾火灾模型": {
|
||||||
|
"model_path": "models/fire_smoke_model/best.pt",
|
||||||
|
"types": ["fire", "smoke"],
|
||||||
|
"type_to_id": {"fire": 0, "smoke": 1},
|
||||||
|
"params": {
|
||||||
|
"enable_primary": True,
|
||||||
|
"primary_conf": 0.99,
|
||||||
|
"secondary_conf": 0.99,
|
||||||
|
"final_conf": 0.99,
|
||||||
|
"enable_multi_scale": True,
|
||||||
|
"multi_scales": [0.75, 1.0, 1.25],
|
||||||
|
"enable_secondary": True,
|
||||||
|
"slice_size": 512,
|
||||||
|
"overlap_ratio": 0.3,
|
||||||
|
"weight_primary": 0.4,
|
||||||
|
"weight_secondary": 0.6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# SAHI自适应切片配置
|
||||||
|
SLICE_RULES = [
|
||||||
|
(12_000_000, (384, 0.35)),
|
||||||
|
(3_000_000, (512, 0.3)),
|
||||||
|
(0, (640, 0.25))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionResponse(BaseModel):
|
||||||
|
hasTarget: int
|
||||||
|
originalImgSize: List[int]
|
||||||
|
targets: List[dict]
|
||||||
|
processing_errors: List[str] = []
|
||||||
313
detect.py
Normal file
313
detect.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from sahi import AutoDetectionModel
|
||||||
|
from sahi.predict import get_sliced_prediction
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF
|
||||||
|
|
||||||
|
|
||||||
|
class YOLODetector:
|
||||||
|
def __init__(self, model_path, params, type_to_id):
|
||||||
|
# 加载YOLO模型
|
||||||
|
self.model = YOLO(model_path)
|
||||||
|
self.model.to(DEVICE)
|
||||||
|
self.class_names = self.model.names
|
||||||
|
self.type_to_id = type_to_id
|
||||||
|
|
||||||
|
self.params = params
|
||||||
|
self.enable_primary = params.get("enable_primary", True)
|
||||||
|
self.primary_conf = params.get("primary_conf", DEFAULT_CONF) # 初级检测阈值
|
||||||
|
self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF) # 次级检测阈值
|
||||||
|
self.final_conf = params.get("final_conf", DEFAULT_CONF) # 最终展示阈值
|
||||||
|
|
||||||
|
# SAHI模型
|
||||||
|
self.sahi_model = None
|
||||||
|
if params["enable_secondary"]:
|
||||||
|
self.sahi_model = AutoDetectionModel.from_pretrained(
|
||||||
|
model_type='yolov8',
|
||||||
|
model_path=model_path,
|
||||||
|
confidence_threshold=self.secondary_conf,
|
||||||
|
device=DEVICE
|
||||||
|
)
|
||||||
|
|
||||||
|
# 统计
|
||||||
|
self.stats = defaultdict(int)
|
||||||
|
|
||||||
|
def get_adaptive_slice(self, total_pixels):
|
||||||
|
"""自适应切片参数"""
|
||||||
|
for pixel_thresh, (size, overlap) in SLICE_RULES:
|
||||||
|
if total_pixels > pixel_thresh:
|
||||||
|
return size, overlap
|
||||||
|
return self.params["slice_size"], self.params["overlap_ratio"]
|
||||||
|
|
||||||
|
def multi_scale_detect(self, img_path):
|
||||||
|
"""多尺度检测(使用模型专属初级阈值)"""
|
||||||
|
detections = []
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
|
||||||
|
for scale in self.params["multi_scales"]:
|
||||||
|
if scale == 1.0:
|
||||||
|
# 原尺度检测
|
||||||
|
results = self.model(
|
||||||
|
img_path,
|
||||||
|
conf=self.primary_conf, # 模型专属初级阈值
|
||||||
|
device=DEVICE,
|
||||||
|
classes=self.target_ids,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 缩放检测
|
||||||
|
nw, nh = int(w * scale), int(h * scale)
|
||||||
|
scaled_img = cv2.resize(img, (nw, nh))
|
||||||
|
temp_path = f"temp_scale_{scale}.jpg"
|
||||||
|
cv2.imwrite(temp_path, scaled_img)
|
||||||
|
|
||||||
|
results = self.model(
|
||||||
|
temp_path,
|
||||||
|
conf=self.primary_conf, # 模型专属初级阈值
|
||||||
|
device=DEVICE,
|
||||||
|
classes=self.target_ids,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
os.remove(temp_path)
|
||||||
|
|
||||||
|
# 解析结果(核心修复:增加对result.boxes为None的判断)
|
||||||
|
for result in results:
|
||||||
|
# 检查boxes是否存在且非空
|
||||||
|
if result.boxes is None:
|
||||||
|
continue
|
||||||
|
for box in result.boxes:
|
||||||
|
bbox = box.xyxy[0].tolist()
|
||||||
|
if scale != 1.0:
|
||||||
|
bbox = [coord / scale for coord in bbox]
|
||||||
|
|
||||||
|
detections.append({
|
||||||
|
"box": bbox,
|
||||||
|
"conf": box.conf[0].item(),
|
||||||
|
"class": box.cls[0].item(),
|
||||||
|
"class_name": self.class_names[int(box.cls[0])],
|
||||||
|
"source": "primary"
|
||||||
|
})
|
||||||
|
|
||||||
|
return detections
|
||||||
|
|
||||||
|
def primary_detect(self, img_path):
|
||||||
|
"""初次检测(使用模型专属初级阈值)- 新增enable_primary判断"""
|
||||||
|
# 新增:如果禁用一级检测,直接返回空列表
|
||||||
|
if not self.enable_primary:
|
||||||
|
self.stats["primary"] = 0
|
||||||
|
print(" 一级检测已禁用,跳过初级检测")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if self.params["enable_multi_scale"]:
|
||||||
|
detections = self.multi_scale_detect(img_path)
|
||||||
|
else:
|
||||||
|
results = self.model(
|
||||||
|
img_path,
|
||||||
|
conf=self.primary_conf, # 模型专属初级阈值
|
||||||
|
device=DEVICE,
|
||||||
|
classes=self.target_ids,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
# 解析结果(核心修复:增加对result.boxes为None的判断)
|
||||||
|
detections = []
|
||||||
|
for result in results:
|
||||||
|
# 检查boxes是否存在且非空
|
||||||
|
if result.boxes is None:
|
||||||
|
continue
|
||||||
|
for box in result.boxes:
|
||||||
|
detections.append({
|
||||||
|
"box": box.xyxy[0].tolist(),
|
||||||
|
"conf": box.conf[0].item(),
|
||||||
|
"class": box.cls[0].item(),
|
||||||
|
"class_name": self.class_names[int(box.cls[0])],
|
||||||
|
"source": "primary"
|
||||||
|
})
|
||||||
|
|
||||||
|
self.stats["primary"] = len(detections)
|
||||||
|
return detections
|
||||||
|
|
||||||
|
def secondary_detect(self, img_path):
|
||||||
|
"""SAHI切片检测(已在初始化时使用模型专属次级阈值)"""
|
||||||
|
if not self.params["enable_secondary"] or not self.sahi_model:
|
||||||
|
return []
|
||||||
|
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
total_pixels = w * h
|
||||||
|
slice_size, overlap = self.get_adaptive_slice(total_pixels)
|
||||||
|
|
||||||
|
# SAHI切片预测
|
||||||
|
sliced_results = get_sliced_prediction(
|
||||||
|
img_path,
|
||||||
|
self.sahi_model,
|
||||||
|
slice_height=slice_size,
|
||||||
|
slice_width=slice_size,
|
||||||
|
overlap_height_ratio=overlap,
|
||||||
|
overlap_width_ratio=overlap,
|
||||||
|
verbose=0
|
||||||
|
)
|
||||||
|
|
||||||
|
detections = []
|
||||||
|
for obj in sliced_results.object_prediction_list:
|
||||||
|
if self.target_ids and obj.category.id not in self.target_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
bbox = obj.bbox.to_xyxy()
|
||||||
|
bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||||
|
|
||||||
|
if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE:
|
||||||
|
detections.append({
|
||||||
|
"box": bbox,
|
||||||
|
"conf": obj.score.value,
|
||||||
|
"class": obj.category.id,
|
||||||
|
"class_name": obj.category.name,
|
||||||
|
"source": "secondary"
|
||||||
|
})
|
||||||
|
|
||||||
|
self.stats["secondary"] = len(detections)
|
||||||
|
return detections
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_iou(box1, box2):
|
||||||
|
"""计算IoU"""
|
||||||
|
x11, y11, x21, y21 = box1
|
||||||
|
x12, y12, x22, y22 = box2
|
||||||
|
|
||||||
|
inter_x1 = max(x11, x12)
|
||||||
|
inter_y1 = max(y11, y12)
|
||||||
|
inter_x2 = min(x21, x22)
|
||||||
|
inter_y2 = min(y21, y22)
|
||||||
|
|
||||||
|
inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
|
||||||
|
area1 = (x21 - x11) * (y21 - y11)
|
||||||
|
area2 = (x22 - x12) * (y22 - y12)
|
||||||
|
union_area = area1 + area2 - inter_area
|
||||||
|
|
||||||
|
return inter_area / union_area if union_area > 0 else 0
|
||||||
|
|
||||||
|
def merge_detections(self, primary_dets, secondary_dets):
|
||||||
|
"""融合检测结果"""
|
||||||
|
if not primary_dets:
|
||||||
|
return secondary_dets
|
||||||
|
if not secondary_dets:
|
||||||
|
return primary_dets
|
||||||
|
|
||||||
|
# 加权置信度
|
||||||
|
all_dets = []
|
||||||
|
for det in primary_dets:
|
||||||
|
det["weighted_conf"] = det["conf"] * self.params["weight_primary"]
|
||||||
|
all_dets.append(det)
|
||||||
|
for det in secondary_dets:
|
||||||
|
det["weighted_conf"] = det["conf"] * self.params["weight_secondary"]
|
||||||
|
all_dets.append(det)
|
||||||
|
|
||||||
|
# 按类别分组融合
|
||||||
|
class_groups = defaultdict(list)
|
||||||
|
for det in all_dets:
|
||||||
|
class_groups[det["class"]].append(det)
|
||||||
|
|
||||||
|
merged = []
|
||||||
|
for cls_id, cls_dets in class_groups.items():
|
||||||
|
cls_dets.sort(key=lambda x: x["weighted_conf"], reverse=True)
|
||||||
|
suppressed = [False] * len(cls_dets)
|
||||||
|
|
||||||
|
for i in range(len(cls_dets)):
|
||||||
|
if suppressed[i]:
|
||||||
|
continue
|
||||||
|
merged.append(cls_dets[i])
|
||||||
|
for j in range(i + 1, len(cls_dets)):
|
||||||
|
if not suppressed[j] and self.calculate_iou(cls_dets[i]["box"], cls_dets[j]["box"]) > DEFAULT_IOU:
|
||||||
|
suppressed[j] = True
|
||||||
|
|
||||||
|
self.stats["merged"] = len(merged)
|
||||||
|
return merged
|
||||||
|
|
||||||
|
def post_process(self, detections):
|
||||||
|
"""后处理(使用模型专属最终阈值)"""
|
||||||
|
# 置信度过滤:模型专属最终阈值
|
||||||
|
filtered = [det for det in detections if det["conf"] >= self.final_conf]
|
||||||
|
|
||||||
|
# 位置去重
|
||||||
|
final_dets = []
|
||||||
|
for curr_det in filtered:
|
||||||
|
curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2
|
||||||
|
curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2
|
||||||
|
curr_cls = curr_det["class"]
|
||||||
|
duplicate = False
|
||||||
|
|
||||||
|
for idx, exist_det in enumerate(final_dets):
|
||||||
|
if exist_det["class"] != curr_cls:
|
||||||
|
continue
|
||||||
|
|
||||||
|
exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2
|
||||||
|
exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2
|
||||||
|
dist = np.sqrt((curr_cx - exist_cx) **2 + (curr_cy - exist_cy)** 2)
|
||||||
|
|
||||||
|
if dist < DEFAULT_POS_THRESH:
|
||||||
|
duplicate = True
|
||||||
|
if curr_det["conf"] > exist_det["conf"]:
|
||||||
|
final_dets[idx] = curr_det
|
||||||
|
break
|
||||||
|
|
||||||
|
if not duplicate:
|
||||||
|
final_dets.append(curr_det)
|
||||||
|
|
||||||
|
self.stats["final"] = len(final_dets)
|
||||||
|
return final_dets
|
||||||
|
|
||||||
|
def format_results(self, detections):
|
||||||
|
"""格式化结果"""
|
||||||
|
formatted = []
|
||||||
|
for det in detections:
|
||||||
|
x1, y1, x2, y2 = det["box"]
|
||||||
|
formatted.append({
|
||||||
|
"type": det["class_name"],
|
||||||
|
"size": [int(round(x2 - x1)), int(round(y2 - y1))],
|
||||||
|
"leftTopPoint": [int(round(x1)), int(round(y1))],
|
||||||
|
"score": round(det["conf"], 4),
|
||||||
|
})
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
def get_detection_stats(self):
|
||||||
|
"""获取检测统计信息"""
|
||||||
|
return dict(self.stats)
|
||||||
|
|
||||||
|
def detect(self, img_path, target_types=None):
|
||||||
|
"""完整检测流程"""
|
||||||
|
# 重置统计
|
||||||
|
self.stats = defaultdict(int)
|
||||||
|
|
||||||
|
# 设置目标类别
|
||||||
|
if target_types:
|
||||||
|
self.target_ids = [self.type_to_id[cls] for cls in target_types if cls in self.type_to_id]
|
||||||
|
else:
|
||||||
|
self.target_ids = None
|
||||||
|
|
||||||
|
# 执行检测
|
||||||
|
primary_dets = self.primary_detect(img_path)
|
||||||
|
print(f" 初级检测后: {self.stats['primary']} 个目标")
|
||||||
|
|
||||||
|
if self.params["enable_secondary"]:
|
||||||
|
secondary_dets = self.secondary_detect(img_path)
|
||||||
|
print(f" 次级检测后: {self.stats['secondary']} 个目标")
|
||||||
|
merged_dets = self.merge_detections(primary_dets, secondary_dets)
|
||||||
|
print(f" 融合去重后: {self.stats['merged']} 个目标")
|
||||||
|
else:
|
||||||
|
merged_dets = primary_dets
|
||||||
|
print(f" 次级检测未启用")
|
||||||
|
|
||||||
|
# 后处理
|
||||||
|
processed_dets = self.post_process(merged_dets)
|
||||||
|
print(f" 过滤低置信度后: {self.stats['final']} 个目标")
|
||||||
|
|
||||||
|
print(" 最终检测目标详情:")
|
||||||
|
for idx, det in enumerate(processed_dets, 1):
|
||||||
|
print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}")
|
||||||
|
|
||||||
|
return self.format_results(processed_dets)
|
||||||
125
main.py
Normal file
125
main.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
import io
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import uvicorn
|
||||||
|
from PIL import Image
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel, HttpUrl
|
||||||
|
|
||||||
|
from config import DetectionResponse
|
||||||
|
from process import detect_large_image_from_url
|
||||||
|
|
||||||
|
# 全局检测管理器
|
||||||
|
detector_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
global detector_manager
|
||||||
|
try:
|
||||||
|
from manager import UnifiedDetectionManager
|
||||||
|
detector_manager = UnifiedDetectionManager()
|
||||||
|
print("检测管理器初始化成功")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"初始化失败:{str(e)}")
|
||||||
|
raise
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan, title="目标检测API", version="1.0.0")
|
||||||
|
|
||||||
|
# 配置跨域请求
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionRequest(BaseModel):
|
||||||
|
type: str
|
||||||
|
url: HttpUrl
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionProcessRequest(BaseModel):
|
||||||
|
url: HttpUrl
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/detect_image", response_model=DetectionResponse)
|
||||||
|
async def run_detection_image(request: DetectionRequest):
|
||||||
|
# 解析检测类型
|
||||||
|
requested_types = {t.strip().lower() for t in request.type.split(',') if t.strip()}
|
||||||
|
print(f"请求的检测类型: {requested_types}")
|
||||||
|
if not requested_types:
|
||||||
|
raise HTTPException(status_code=400, detail="未指定检测类型")
|
||||||
|
|
||||||
|
# 下载图片
|
||||||
|
try:
|
||||||
|
response = requests.get(str(request.url), timeout=15)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 获取图片尺寸
|
||||||
|
with Image.open(io.BytesIO(response.content)) as img:
|
||||||
|
img_size = [img.width, img.height]
|
||||||
|
|
||||||
|
# 创建临时文件
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
|
||||||
|
temp_file.write(response.content)
|
||||||
|
temp_path = temp_file.name
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"图片处理失败:{str(e)}")
|
||||||
|
|
||||||
|
# 执行检测
|
||||||
|
results = []
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
detection_results = detector_manager.detect(temp_path, ",".join(requested_types))
|
||||||
|
if detection_results:
|
||||||
|
results = detection_results
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(f"检测失败:{str(e)}")
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if os.path.exists(temp_path):
|
||||||
|
os.remove(temp_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hasTarget": 1 if results else 0,
|
||||||
|
"originalImgSize": img_size,
|
||||||
|
"targets": results,
|
||||||
|
"processing_errors": errors
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/detect_process", response_model=DetectionResponse)
|
||||||
|
async def run_detection_process(request: DetectionProcessRequest):
|
||||||
|
return detect_large_image_from_url(str(request.url))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/supported_types")
|
||||||
|
async def get_supported_types():
|
||||||
|
if detector_manager:
|
||||||
|
info = detector_manager.get_available_info()
|
||||||
|
return {
|
||||||
|
"supported_types": info["supported_types"],
|
||||||
|
}
|
||||||
|
return {"supported_types": []}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--host", default="0.0.0.0")
|
||||||
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
|
parser.add_argument("--reload", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
uvicorn.run("main:app", host=args.host, port=args.port, reload=args.reload)
|
||||||
116
manager.py
Normal file
116
manager.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from config import MODEL_CONFIGS
|
||||||
|
from detect import YOLODetector
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedDetectionManager:
|
||||||
|
"""统一检测管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.detectors = {} # 检测器实例
|
||||||
|
self.type_to_model = {} # 类别到模型映射
|
||||||
|
self.loaded_models = [] # 已加载模型
|
||||||
|
self.type_to_id = {} # 全局类别ID映射
|
||||||
|
|
||||||
|
self._load_models()
|
||||||
|
|
||||||
|
def _load_models(self):
|
||||||
|
"""加载所有模型"""
|
||||||
|
if not MODEL_CONFIGS:
|
||||||
|
raise ValueError("模型配置为空")
|
||||||
|
|
||||||
|
for model_name, config in MODEL_CONFIGS.items():
|
||||||
|
try:
|
||||||
|
model_path = config["model_path"]
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print(f"跳过 {model_name}: 模型文件不存在 - {model_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建检测器(自动传递新增的enable_primary配置)
|
||||||
|
detector = YOLODetector(
|
||||||
|
model_path=model_path,
|
||||||
|
params=config["params"],
|
||||||
|
type_to_id=config["type_to_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存状态
|
||||||
|
self.detectors[model_name] = detector
|
||||||
|
self.loaded_models.append(model_name)
|
||||||
|
|
||||||
|
# 建立映射
|
||||||
|
for det_type in config["types"]:
|
||||||
|
det_type_lower = det_type.lower()
|
||||||
|
if det_type_lower in self.type_to_model:
|
||||||
|
print(f"警告: 类别 '{det_type}' 映射冲突")
|
||||||
|
self.type_to_model[det_type_lower] = model_name
|
||||||
|
self.type_to_id[det_type_lower] = config["type_to_id"][det_type_lower]
|
||||||
|
|
||||||
|
print(f"加载成功: {model_name}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载失败 {model_name}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}")
|
||||||
|
print(f"支持类别: {list(self.type_to_model.keys())}")
|
||||||
|
|
||||||
|
def parse_types(self, types_str):
|
||||||
|
"""解析检测类型"""
|
||||||
|
if not types_str:
|
||||||
|
raise ValueError("检测类型为空")
|
||||||
|
|
||||||
|
# 清理输入
|
||||||
|
requested_types = list(set(t.strip().lower() for t in types_str.split(',') if t.strip()))
|
||||||
|
|
||||||
|
# 按模型分组
|
||||||
|
model_type_map = defaultdict(list)
|
||||||
|
for det_type in requested_types:
|
||||||
|
if det_type in self.type_to_model:
|
||||||
|
model_name = self.type_to_model[det_type]
|
||||||
|
model_type_map[model_name].append(det_type)
|
||||||
|
else:
|
||||||
|
print(f"忽略未知类别: {det_type}")
|
||||||
|
|
||||||
|
if not model_type_map:
|
||||||
|
raise ValueError("无有效检测类别")
|
||||||
|
|
||||||
|
return model_type_map
|
||||||
|
|
||||||
|
def detect(self, img_path, detection_types):
|
||||||
|
"""执行检测"""
|
||||||
|
if not os.path.exists(img_path):
|
||||||
|
raise FileNotFoundError(f"图像不存在: {img_path}")
|
||||||
|
|
||||||
|
# 解析类型
|
||||||
|
model_type_map = self.parse_types(detection_types)
|
||||||
|
|
||||||
|
# 执行检测(自动适配enable_primary配置)
|
||||||
|
all_results = []
|
||||||
|
for model_name, target_types in model_type_map.items():
|
||||||
|
if model_name not in self.detectors:
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"检测: {model_name} -> {target_types}")
|
||||||
|
try:
|
||||||
|
results = self.detectors[model_name].detect(img_path, target_types)
|
||||||
|
all_results.extend(results)
|
||||||
|
|
||||||
|
# 获取详细统计信息
|
||||||
|
stats = self.detectors[model_name].get_detection_stats()
|
||||||
|
print(f" {model_name}详细统计: {stats}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"检测失败 {model_name}: {str(e)}")
|
||||||
|
|
||||||
|
print(f"检测完成: 总共 {len(all_results)} 个结果")
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
def get_available_info(self):
|
||||||
|
"""获取可用信息"""
|
||||||
|
return {
|
||||||
|
"loaded_models": self.loaded_models,
|
||||||
|
"supported_types": list(self.type_to_model.keys()),
|
||||||
|
"type_to_model": self.type_to_model
|
||||||
|
}
|
||||||
BIN
models/fire_smoke_model/best.pt
Normal file
BIN
models/fire_smoke_model/best.pt
Normal file
Binary file not shown.
BIN
models/ppe_state_model/best.pt
Normal file
BIN
models/ppe_state_model/best.pt
Normal file
Binary file not shown.
BIN
models/solor_bracket_model/best.pt
Normal file
BIN
models/solor_bracket_model/best.pt
Normal file
Binary file not shown.
180
process.py
Normal file
180
process.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from tqdm import tqdm
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
|
||||||
|
# 定义返回值模型
|
||||||
|
class DetectionResponse(BaseModel):
|
||||||
|
hasTarget: int
|
||||||
|
originalImgSize: List[int]
|
||||||
|
targets: List[dict]
|
||||||
|
processing_errors: List[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
def download_large_file(url, chunk_size=1024 * 1024):
|
||||||
|
"""下载大型文件到临时文件、返回临时文件路径"""
|
||||||
|
try:
|
||||||
|
response = requests.get(url, stream=True, timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
file_size = int(response.headers.get('Content-Length', 0))
|
||||||
|
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') # 适配png格式
|
||||||
|
temp_file_path = temp_file.name
|
||||||
|
temp_file.close()
|
||||||
|
|
||||||
|
with open(temp_file_path, 'wb') as f, tqdm(
|
||||||
|
total=file_size, unit='B', unit_scale=True,
|
||||||
|
desc=f"下载 {os.path.basename(urlparse(url).path)}"
|
||||||
|
) as pbar:
|
||||||
|
for chunk in response.iter_content(chunk_size=chunk_size):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
pbar.update(len(chunk))
|
||||||
|
|
||||||
|
return temp_file_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"下载失败: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
|
if 'temp_file_path' in locals():
|
||||||
|
try:
|
||||||
|
os.remove(temp_file_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
raise Exception(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def slice_large_image(image_path, slice_size=1024, overlap=100):
|
||||||
|
"""切分大图为切片、返回切片数据和位置信息"""
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
if img is None:
|
||||||
|
raise ValueError(f"无法读取图像: {image_path}")
|
||||||
|
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
step = slice_size - overlap
|
||||||
|
num_rows = (h + step - 1) // step
|
||||||
|
num_cols = (w + step - 1) // step
|
||||||
|
|
||||||
|
slices = []
|
||||||
|
for i in range(num_rows):
|
||||||
|
for j in range(num_cols):
|
||||||
|
y1 = i * step
|
||||||
|
x1 = j * step
|
||||||
|
y2 = min(y1 + slice_size, h)
|
||||||
|
x2 = min(x1 + slice_size, w)
|
||||||
|
|
||||||
|
if y2 - y1 < slice_size:
|
||||||
|
y1 = max(0, y2 - slice_size)
|
||||||
|
if x2 - x1 < slice_size:
|
||||||
|
x1 = max(0, x2 - slice_size)
|
||||||
|
|
||||||
|
slice_img = img[y1:y2, x1:x2]
|
||||||
|
slices.append((x1, y1, slice_img))
|
||||||
|
|
||||||
|
return slices, (h, w)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_detection_info(result, slice_offset_x, slice_offset_y):
|
||||||
|
"""从YOLO OBB结果中提取检测框信息(修正宽高计算)"""
|
||||||
|
detections = []
|
||||||
|
|
||||||
|
if result.obb is not None and len(result.obb) > 0:
|
||||||
|
obb_data = result.obb
|
||||||
|
obb_xyxy = obb_data.xyxy.cpu().numpy()
|
||||||
|
classes = obb_data.cls.cpu().numpy()
|
||||||
|
confidences = obb_data.conf.cpu().numpy()
|
||||||
|
|
||||||
|
for i in range(len(obb_data)):
|
||||||
|
x1_slice, y1_slice, x2_slice, y2_slice = obb_xyxy[i]
|
||||||
|
# 计算实际宽高(x方向为宽,y方向为高)
|
||||||
|
width = x2_slice - x1_slice
|
||||||
|
height = y2_slice - y1_slice
|
||||||
|
|
||||||
|
# 转换为全局坐标
|
||||||
|
x1_global = x1_slice + slice_offset_x
|
||||||
|
y1_global = y1_slice + slice_offset_y
|
||||||
|
|
||||||
|
cls_id = int(classes[i])
|
||||||
|
confidence = float(confidences[i])
|
||||||
|
class_name = result.names[cls_id]
|
||||||
|
|
||||||
|
detection_info = {
|
||||||
|
"type": class_name,
|
||||||
|
"size": [int(round(width)), int(round(height))],
|
||||||
|
"leftTopPoint": [int(round(x1_global)), int(round(y1_global))],
|
||||||
|
"score": round(confidence, 4)
|
||||||
|
}
|
||||||
|
detections.append(detection_info)
|
||||||
|
|
||||||
|
return detections
|
||||||
|
|
||||||
|
|
||||||
|
def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100) -> DetectionResponse:
|
||||||
|
"""
|
||||||
|
封装后的检测方法:从图片URL处理大图、返回DetectionResponse对象
|
||||||
|
"""
|
||||||
|
# 动态拼接固定model_path(当前文件同级目录下)
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt")
|
||||||
|
|
||||||
|
processing_errors = []
|
||||||
|
all_detections = []
|
||||||
|
original_size = [0, 0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 验证模型文件是否存在
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"模型文件不存在:{model_path}")
|
||||||
|
|
||||||
|
# 下载图像
|
||||||
|
temp_file_path = download_large_file(image_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 切分图像
|
||||||
|
slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap)
|
||||||
|
original_size = [w, h]
|
||||||
|
print(f"完成切片: 共 {len(slices_info)} 个切片")
|
||||||
|
|
||||||
|
# 加载模型并预测
|
||||||
|
model = YOLO(model_path)
|
||||||
|
print("开始逐张预测切片...")
|
||||||
|
|
||||||
|
for i, (x1, y1, slice_img) in enumerate(slices_info, 1):
|
||||||
|
print(f"预测第 {i}/{len(slices_info)} 个切片")
|
||||||
|
result = model(slice_img, conf=0.5, verbose=False)[0]
|
||||||
|
slice_detections = extract_detection_info(result, x1, y1)
|
||||||
|
all_detections.extend(slice_detections)
|
||||||
|
print(f" 本切片检测到 {len(slice_detections)} 个目标")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 确保临时文件删除
|
||||||
|
if os.path.exists(temp_file_path):
|
||||||
|
try:
|
||||||
|
os.remove(temp_file_path)
|
||||||
|
print("临时文件已删除")
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"删除临时文件失败: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
|
processing_errors.append(error_msg)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 捕获所有异常并记录
|
||||||
|
error_msg = str(e)
|
||||||
|
processing_errors.append(error_msg)
|
||||||
|
print(f"处理异常: {error_msg}")
|
||||||
|
|
||||||
|
# 构建并返回DetectionResponse对象
|
||||||
|
return DetectionResponse(
|
||||||
|
hasTarget=1 if len(all_detections) > 0 else 0,
|
||||||
|
originalImgSize=original_size,
|
||||||
|
targets=all_detections,
|
||||||
|
processing_errors=processing_errors
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user