commit 61c3f269460785972fa8ecad9107c6faa453af7d Author: ninghongbin <2409766686@qq.com> Date: Tue Dec 2 17:16:26 2025 +0800 智能体加检测 diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/AI_Agent.py b/AI_Agent.py new file mode 100644 index 0000000..5863b70 --- /dev/null +++ b/AI_Agent.py @@ -0,0 +1,329 @@ +import os +import re +import time +import subprocess +from pathlib import Path +from typing import List, Dict +from docx import Document +from shutil import which +import requests + +INPUT_WORD = r"C:\Users\YC\Desktop\1.docx" # 你的招标文件 +OUTPUT_WORD = r"C:\Users\YC\Desktop\投标文件-最终版.docx" # 最终输出路径 +OLLAMA_MODEL = "alibayram/Qwen3-30B-A3B-Instruct-2507:latest" # 当前最强本地模型 +OLLAMA_BASE_URL = "http://192.168.110.5:11434" + + + +# ==================== Ollama 本地调用(支持 128K 上下文 + 长输出)=================== +# ==================== 终极稳版 call_llm(彻底解决超时 + 支持所有参数)=================== +def call_llm(messages: List[Dict], temperature=0.3, max_tokens=32768, num_ctx=131072): + url = f"{OLLAMA_BASE_URL}/api/chat" + + payload = { + "model": OLLAMA_MODEL, + "messages": messages, + "stream": False, + "temperature": temperature, + "options": { + "num_ctx": num_ctx, # 128K 上下文 + "num_predict": max_tokens, # 最大输出长度 + "num_gpu": 999, # 全GPU加速 + "top_p": 0.95, + "top_k": 40, + "repeat_penalty": 1.08, + "mirostat": 2, + "mirostat_tau": 5.0 + } + } + + headers = {"Content-Type": "application/json"} + + # 最多重试 6 次,指数退避 + for attempt in range(6): + try: + print(f" → 正在调用模型(第{attempt + 1}次尝试,最大等待15分钟)...") + response = requests.post( + url, + json=payload, + headers=headers, + timeout=900 # 关键!15分钟超时,足够生成目录了 + ) + response.raise_for_status() + data = response.json() + + if "message" not in data or "content" not in data["message"]: + raise ValueError("返回格式异常") + + content = data["message"]["content"].strip() + print(f" √ 模型返回成功,本次生成约 {len(content) // 2} 字") + return content + + except requests.exceptions.Timeout: + print(f" × 第{attempt + 1}次超时(15分钟未返回),10秒后重试...") + time.sleep(10) + except requests.exceptions.RequestException as e: + print(f" × 第{attempt + 1}次网络错误:{e},10秒后重试...") + time.sleep(10) + except Exception as e: + print(f" × 未知错误:{e}") + time.sleep(5) + + print(" × 模型彻底失联,返回保底内容") + return "【模型响应失败,已启用保底方案】" + + +# ==================== Word → Markdown(不变,超稳)=================== +def word_to_md(word_path: str) -> str: + md_path = os.path.splitext(word_path)[0] + "_tender.md" + print(f"正在转换招标文件 → Markdown:{os.path.basename(word_path)}") + + pandoc_cmd = which("pandoc") or which("pandoc.exe") + if not pandoc_cmd: + common = [ + os.path.expanduser(r"~\AppData\Local\Pandoc\pandoc.exe"), + r"C:\Program Files\Pandoc\pandoc.exe", + ] + for p in common: + if os.path.exists(p): + pandoc_cmd = p + break + + if pandoc_cmd: + result = subprocess.run([pandoc_cmd, word_path, "-t", "markdown", "-o", md_path, + "--extract-media=media", "--wrap=none"], + capture_output=True, text=True) + if result.returncode == 0: + print("Pandoc 转换成功!") + return md_path + + print("Pandoc 未找到,使用 python-docx 兜底...") + doc = Document(word_path) + text = "\n\n".join(p.text for p in doc.paragraphs if p.text.strip()) + Path(md_path).write_text(text, encoding="utf-8") + print("纯文本提取完成!") + return md_path + + +# ==================== 生成超详细四级目录(利用 128K 上下文)=================== +# ==================== 生成超详细四级目录(已修复语法 + 增强稳定性)=================== +# ==================== 新版:两步生成超级目录(永不超时)=================== +def generate_full_outline(tender_md: str) -> str: + tender_text = Path(tender_md).read_text(encoding="utf-8") + print(f"招标文件共 {len(tender_text)//2} 字,开始两阶段生成四级目录...") + + # 第一步:先让模型只看前 6 万字,生成一个【简洁但完整】的三级目录(超快,10秒内出) + prompt1 = f"""请仔细阅读以下招标文件核心内容,只输出一个简洁但完整的三级目录(一级用“一、”,二级用“1、”,三级用“1.1、”)。 +不要四级标题,不要任何说明文字,不要页码。 + +招标文件摘录(最关键部分): +{tender_text[:60000]} + +直接输出三级目录:""" + + print("第1步:生成三级骨架(10秒内必出)...") + outline_skeleton = call_llm([{"role": "user", "content": prompt1}], + temperature=0.01, max_tokens=10000) + + # 第二步:拿着这个骨架,再让模型把每个三级标题下面展开成 8~15 个四级标题(分批进行,永不超时) + print("第2步:开始把每个三级标题展开成四级...") + final_lines = [] + level3_titles = [] + current_level3 = "" + + for line in outline_skeleton.split('\n'): + line = line.strip() + if re.match(r'^\d+\.\d+、', line) or re.match(r'^\d+\.\d+ ', line): + current_level3 = line + level3_titles.append(current_level3) + final_lines.append(line) # 三级原样保留 + elif line and not line.startswith(('一、', '二、', '三、', '四、', '五、', '六、', '七、', '八、')): + final_lines.append(line) + + # 每 8个三级标题为一组,展开四级(稳到爆) + full_outline = outline_skeleton + "\n" + for i in range(0, len(level3_titles), 8): + batch = level3_titles[i:i+8] + batch_text = "\n".join(batch) + + prompt2 = f"""你是一位招投标专家,请把下面这几个三级标题分别展开成 10~18 个专业四级标题(格式必须是 1.1.1、1.1.2、……)。 +只输出四级标题部分,不要重复三级标题本身。 + +需要展开的三级标题: +{batch_text} + +招标文件关键要求(用于展开参考): +{tender_text[:50000]} + +直接输出四级标题:""" + + print(f" 正在展开第 {i//8 + 1} 组四级标题({len(batch)}个)...") + level4_text = call_llm([{"role": "user", "content": prompt2}], + temperature=0.2, max_tokens=20000) + full_outline += "\n" + level4_text + "\n" + time.sleep(2) + + # 保存并返回 + Path("output/四级目录.md").write_text(full_outline, encoding="utf-8") + print(f"超级四级目录生成成功!总计约 {len(full_outline)//2} 字(再也不怕超时了!)") + return full_outline + + +# ==================== 分批生成正文(每批最多6个四级标题,避免超上下文)=================== +def batch_fill_content(outline: str, tender_text: str) -> str: + level4_titles = [line.strip() for line in outline.split('\n') + if + re.match(r'^\d+\.\d+\.\d+、', line.strip()) or re.match(r'^[0-9]+\.[0-9]+\.[0-9]+ ', line.strip())] + print(f"共检测到 {len(level4_titles)} 个四级标题,将分批生成详细内容...") + + all_content = ["# 正文内容开始"] + batch_size = 6 # Qwen3-30B 128K 下,6个四级标题 + 招标文件摘要 ≈ 80K tokens,安全 + + for i in range(0, len(level4_titles), batch_size): + batch = level4_titles[i:i + batch_size] + titles_str = "\n".join(batch) + + prompt = f"""请为以下【{len(batch)}个四级标题】撰写极其详细、专业、可直接用于正式投标的正文内容。 + +要求每小节: +- 500—1000字(内容充实、逻辑严密) +- 至少包含 2 张以上专业 Markdown 表格(如进度表、资源配置表、检测项目表等) +- 使用【投标单位全称】【项目负责人】【联系电话】等占位符 +- 语言正式、响应招标文件每一项要求 +- 图文并茂(插入流程图、架构图说明文字) + +当前批次标题: +{titles_str} + +招标文件核心要求摘要(已精炼): +{tender_text[:60000]} # 控制在6万字以内,避免超上下文 + +请按顺序为每个标题撰写完整内容,用 --- 分隔。""" + + print(f"正在生成第 {i // batch_size + 1}/{len(level4_titles) // batch_size + 1} 批({len(batch)}个小节)...") + part = call_llm([{"role": "user", "content": prompt}], temperature=0.45, max_tokens=32000) + all_content.append(part) + time.sleep(2) # 礼貌等待,避免打满GPU + + final_content = "\n\n---\n\n".join(all_content) + Path("output/正文内容.md").write_text(final_content, encoding="utf-8") + print(f"所有正文生成完成!总计约 {len(final_content) // 2} 字") + return final_content + + +# ==================== 本地扩容到 5 万字+(美观填充)=================== +def expand_to_50000_words(content: str) -> str: + current = len(content) + if current >= 100000: + return content + print(f"当前 {current // 2} 字,正在补充至 5 万字+...") + # 补充常见必备内容 + appendix = """ +### 六、售后服务体系 +#### 6.1 服务承诺 +我单位承诺:7×24小时响应,2小时内到达现场,终身免费维护核心系统... + +#### 6.2 维保人员配置表 +| 序号 | 岗位 | 姓名 | 资质证书 | 联系方式 | +|------|------------|----------|----------------------|--------------| +| 1 | 项目经理 | 【项目负责人】 | PMP、一级建造师 | 138xxxxxxx | + +### 七、类似工程业绩 +| 序号 | 项目名称 | 业主单位 | 合同金额(万元) | 完成时间 | 联系人 | +|------|--------------------------|------------|----------------|----------|----------| +| 1 | xx市智慧交通一期工程 | xx市交通局 | 3860 | 2024.12 | 张工 | +""" + content += appendix * 15 + return content + + +# ==================== 强制刷新 Word 目录(同前)=================== +def update_word_toc(docx_path: str): + try: + import win32com.client as win32 + import pythoncom + pythoncom.CoInitialize() + word = win32.Dispatch('Word.Application') + word.Visible = False + doc = word.Documents.Open(os.path.abspath(docx_path)) + for toc in doc.TablesOfContents: + toc.Update() + doc.Save() + doc.Close() + word.Quit() + except Exception as e: + print(f"Word目录自动更新失败(可手动右键更新):{e}") + + +# ==================== 主流程 ==================== +def main(): + print("启动本地 Qwen3-30B 投标文件生成器(128K上下文版)\n") + os.makedirs("output", exist_ok=True) + + # 1. 转换招标文件 + tender_md = word_to_md(INPUT_WORD) + tender_text = Path(tender_md).read_text(encoding="utf-8") + + # 2. 生成超级详细目录 + outline = generate_full_outline(tender_md) + + # 3. 分批生成正文(超长内容 + content = batch_fill_content(outline, tender_text) + content = expand_to_50000_words(content) + + # 4. 合成最终 Markdown + final_md = f"""# 【投标单位全称】 + +## {Path(INPUT_WORD).stem} - 投标文件 + +{outline} + +{content} + +## 附件清单 +- 营业执照(副本) +- 法人授权委托书 +- 资质证书扫描件 +- 类似业绩证明材料 +- 偏离表 +""" + final_md_path = "output/最终投标文件.md" + Path(final_md_path).write_text(final_md, encoding="utf-8") + print(f"\n最终 Markdown 生成成功!总计约 {len(final_md) // 2} 字") + + # 5. 转 Word(三保险) + print("正在转换为 Word 文档...") + success = False + pandoc_cmd = which("pandoc") or which("pandoc.exe") + if pandoc_cmd and os.path.exists(pandoc_cmd): + cmd = [pandoc_cmd, final_md_path, "-o", OUTPUT_WORD, "--reference-doc=template.docx"] if os.path.exists( + "template.docx") else [pandoc_cmd, final_md_path, "-o", OUTPUT_WORD] + if subprocess.run(cmd, capture_output=True).returncode == 0: + success = True + + if not success: + print("Pandoc 失败,使用 python-docx 强制生成...") + doc = Document() + for line in final_md.split('\n'): + l = line.strip() + if l.startswith("# "): + doc.add_heading(l[2:], 0) + elif l.startswith("## "): + doc.add_heading(l[3:], 1) + elif l.startswith("### "): + doc.add_heading(l[4:], 2) + elif l.startswith("#### "): + doc.add_heading(l[5:], 3) + elif l: + doc.add_paragraph(l) + doc.save(OUTPUT_WORD) + + update_word_toc(OUTPUT_WORD) + print(f"\n大功告成!投标文件已生成:") + print(f" {OUTPUT_WORD}") + print(f" 总字数约:{len(final_md) // 2} 字") + os.startfile(OUTPUT_WORD) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/__pycache__/config.cpython-312.pyc b/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000..879e978 Binary files /dev/null and b/__pycache__/config.cpython-312.pyc differ diff --git a/__pycache__/detect.cpython-312.pyc b/__pycache__/detect.cpython-312.pyc new file mode 100644 index 0000000..b19af65 Binary files /dev/null and b/__pycache__/detect.cpython-312.pyc differ diff --git a/__pycache__/main.cpython-312.pyc b/__pycache__/main.cpython-312.pyc new file mode 100644 index 0000000..c962065 Binary files /dev/null and b/__pycache__/main.cpython-312.pyc differ diff --git a/__pycache__/manager.cpython-312.pyc b/__pycache__/manager.cpython-312.pyc new file mode 100644 index 0000000..b184d0e Binary files /dev/null and b/__pycache__/manager.cpython-312.pyc differ diff --git a/__pycache__/process.cpython-312.pyc b/__pycache__/process.cpython-312.pyc new file mode 100644 index 0000000..95090ea Binary files /dev/null and b/__pycache__/process.cpython-312.pyc differ diff --git a/config.py b/config.py new file mode 100644 index 0000000..d96a5a2 --- /dev/null +++ b/config.py @@ -0,0 +1,66 @@ +from typing import List + +import torch +from pydantic import BaseModel + +# 设备配置 +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" + +# 默认检测参数 +DEFAULT_CONF = 0.25 +DEFAULT_IOU = 0.5 +DEFAULT_MIN_SIZE = 8 +DEFAULT_POS_THRESH = 5 + +MODEL_CONFIGS = { + "安全施工模型": { + "model_path": "models/ppe_state_model/best.pt", + "types": ["novest", "nohelmet"], + "type_to_id": {"novest": 0, "nohelmet": 2}, + "params": { + "enable_primary": True, + "primary_conf": 0.55, + "secondary_conf": 0.6, + "final_conf": 0.65, + "enable_multi_scale": True, + "multi_scales": [0.75, 1.0, 1.25], + "enable_secondary": True, + "slice_size": 512, + "overlap_ratio": 0.3, + "weight_primary": 0.4, + "weight_secondary": 0.6 + } + }, + "烟雾火灾模型": { + "model_path": "models/fire_smoke_model/best.pt", + "types": ["fire", "smoke"], + "type_to_id": {"fire": 0, "smoke": 1}, + "params": { + "enable_primary": True, + "primary_conf": 0.99, + "secondary_conf": 0.99, + "final_conf": 0.99, + "enable_multi_scale": True, + "multi_scales": [0.75, 1.0, 1.25], + "enable_secondary": True, + "slice_size": 512, + "overlap_ratio": 0.3, + "weight_primary": 0.4, + "weight_secondary": 0.6 + } + } +} + +# SAHI自适应切片配置 +SLICE_RULES = [ + (12_000_000, (384, 0.35)), + (3_000_000, (512, 0.3)), + (0, (640, 0.25)) +] + + +class DetectionResponse(BaseModel): + hasTarget: int + originalImgSize: List[int] + targets: List[dict] + processing_errors: List[str] = [] diff --git a/detect.py b/detect.py new file mode 100644 index 0000000..878cd9c --- /dev/null +++ b/detect.py @@ -0,0 +1,313 @@ +import os +from collections import defaultdict + +import cv2 +import numpy as np +from sahi import AutoDetectionModel +from sahi.predict import get_sliced_prediction +from ultralytics import YOLO + +from config import DEVICE, DEFAULT_IOU, DEFAULT_MIN_SIZE, DEFAULT_POS_THRESH, SLICE_RULES, DEFAULT_CONF + + +class YOLODetector: + def __init__(self, model_path, params, type_to_id): + # 加载YOLO模型 + self.model = YOLO(model_path) + self.model.to(DEVICE) + self.class_names = self.model.names + self.type_to_id = type_to_id + + self.params = params + self.enable_primary = params.get("enable_primary", True) + self.primary_conf = params.get("primary_conf", DEFAULT_CONF) # 初级检测阈值 + self.secondary_conf = params.get("secondary_conf", DEFAULT_CONF) # 次级检测阈值 + self.final_conf = params.get("final_conf", DEFAULT_CONF) # 最终展示阈值 + + # SAHI模型 + self.sahi_model = None + if params["enable_secondary"]: + self.sahi_model = AutoDetectionModel.from_pretrained( + model_type='yolov8', + model_path=model_path, + confidence_threshold=self.secondary_conf, + device=DEVICE + ) + + # 统计 + self.stats = defaultdict(int) + + def get_adaptive_slice(self, total_pixels): + """自适应切片参数""" + for pixel_thresh, (size, overlap) in SLICE_RULES: + if total_pixels > pixel_thresh: + return size, overlap + return self.params["slice_size"], self.params["overlap_ratio"] + + def multi_scale_detect(self, img_path): + """多尺度检测(使用模型专属初级阈值)""" + detections = [] + img = cv2.imread(img_path) + h, w = img.shape[:2] + + for scale in self.params["multi_scales"]: + if scale == 1.0: + # 原尺度检测 + results = self.model( + img_path, + conf=self.primary_conf, # 模型专属初级阈值 + device=DEVICE, + classes=self.target_ids, + verbose=False + ) + else: + # 缩放检测 + nw, nh = int(w * scale), int(h * scale) + scaled_img = cv2.resize(img, (nw, nh)) + temp_path = f"temp_scale_{scale}.jpg" + cv2.imwrite(temp_path, scaled_img) + + results = self.model( + temp_path, + conf=self.primary_conf, # 模型专属初级阈值 + device=DEVICE, + classes=self.target_ids, + verbose=False + ) + os.remove(temp_path) + + # 解析结果(核心修复:增加对result.boxes为None的判断) + for result in results: + # 检查boxes是否存在且非空 + if result.boxes is None: + continue + for box in result.boxes: + bbox = box.xyxy[0].tolist() + if scale != 1.0: + bbox = [coord / scale for coord in bbox] + + detections.append({ + "box": bbox, + "conf": box.conf[0].item(), + "class": box.cls[0].item(), + "class_name": self.class_names[int(box.cls[0])], + "source": "primary" + }) + + return detections + + def primary_detect(self, img_path): + """初次检测(使用模型专属初级阈值)- 新增enable_primary判断""" + # 新增:如果禁用一级检测,直接返回空列表 + if not self.enable_primary: + self.stats["primary"] = 0 + print(" 一级检测已禁用,跳过初级检测") + return [] + + if self.params["enable_multi_scale"]: + detections = self.multi_scale_detect(img_path) + else: + results = self.model( + img_path, + conf=self.primary_conf, # 模型专属初级阈值 + device=DEVICE, + classes=self.target_ids, + verbose=False + ) + # 解析结果(核心修复:增加对result.boxes为None的判断) + detections = [] + for result in results: + # 检查boxes是否存在且非空 + if result.boxes is None: + continue + for box in result.boxes: + detections.append({ + "box": box.xyxy[0].tolist(), + "conf": box.conf[0].item(), + "class": box.cls[0].item(), + "class_name": self.class_names[int(box.cls[0])], + "source": "primary" + }) + + self.stats["primary"] = len(detections) + return detections + + def secondary_detect(self, img_path): + """SAHI切片检测(已在初始化时使用模型专属次级阈值)""" + if not self.params["enable_secondary"] or not self.sahi_model: + return [] + + img = cv2.imread(img_path) + h, w = img.shape[:2] + total_pixels = w * h + slice_size, overlap = self.get_adaptive_slice(total_pixels) + + # SAHI切片预测 + sliced_results = get_sliced_prediction( + img_path, + self.sahi_model, + slice_height=slice_size, + slice_width=slice_size, + overlap_height_ratio=overlap, + overlap_width_ratio=overlap, + verbose=0 + ) + + detections = [] + for obj in sliced_results.object_prediction_list: + if self.target_ids and obj.category.id not in self.target_ids: + continue + + bbox = obj.bbox.to_xyxy() + bw, bh = bbox[2] - bbox[0], bbox[3] - bbox[1] + + if bw >= DEFAULT_MIN_SIZE and bh >= DEFAULT_MIN_SIZE: + detections.append({ + "box": bbox, + "conf": obj.score.value, + "class": obj.category.id, + "class_name": obj.category.name, + "source": "secondary" + }) + + self.stats["secondary"] = len(detections) + return detections + + @staticmethod + def calculate_iou(box1, box2): + """计算IoU""" + x11, y11, x21, y21 = box1 + x12, y12, x22, y22 = box2 + + inter_x1 = max(x11, x12) + inter_y1 = max(y11, y12) + inter_x2 = min(x21, x22) + inter_y2 = min(y21, y22) + + inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1) + area1 = (x21 - x11) * (y21 - y11) + area2 = (x22 - x12) * (y22 - y12) + union_area = area1 + area2 - inter_area + + return inter_area / union_area if union_area > 0 else 0 + + def merge_detections(self, primary_dets, secondary_dets): + """融合检测结果""" + if not primary_dets: + return secondary_dets + if not secondary_dets: + return primary_dets + + # 加权置信度 + all_dets = [] + for det in primary_dets: + det["weighted_conf"] = det["conf"] * self.params["weight_primary"] + all_dets.append(det) + for det in secondary_dets: + det["weighted_conf"] = det["conf"] * self.params["weight_secondary"] + all_dets.append(det) + + # 按类别分组融合 + class_groups = defaultdict(list) + for det in all_dets: + class_groups[det["class"]].append(det) + + merged = [] + for cls_id, cls_dets in class_groups.items(): + cls_dets.sort(key=lambda x: x["weighted_conf"], reverse=True) + suppressed = [False] * len(cls_dets) + + for i in range(len(cls_dets)): + if suppressed[i]: + continue + merged.append(cls_dets[i]) + for j in range(i + 1, len(cls_dets)): + if not suppressed[j] and self.calculate_iou(cls_dets[i]["box"], cls_dets[j]["box"]) > DEFAULT_IOU: + suppressed[j] = True + + self.stats["merged"] = len(merged) + return merged + + def post_process(self, detections): + """后处理(使用模型专属最终阈值)""" + # 置信度过滤:模型专属最终阈值 + filtered = [det for det in detections if det["conf"] >= self.final_conf] + + # 位置去重 + final_dets = [] + for curr_det in filtered: + curr_cx = (curr_det["box"][0] + curr_det["box"][2]) / 2 + curr_cy = (curr_det["box"][1] + curr_det["box"][3]) / 2 + curr_cls = curr_det["class"] + duplicate = False + + for idx, exist_det in enumerate(final_dets): + if exist_det["class"] != curr_cls: + continue + + exist_cx = (exist_det["box"][0] + exist_det["box"][2]) / 2 + exist_cy = (exist_det["box"][1] + exist_det["box"][3]) / 2 + dist = np.sqrt((curr_cx - exist_cx) **2 + (curr_cy - exist_cy)** 2) + + if dist < DEFAULT_POS_THRESH: + duplicate = True + if curr_det["conf"] > exist_det["conf"]: + final_dets[idx] = curr_det + break + + if not duplicate: + final_dets.append(curr_det) + + self.stats["final"] = len(final_dets) + return final_dets + + def format_results(self, detections): + """格式化结果""" + formatted = [] + for det in detections: + x1, y1, x2, y2 = det["box"] + formatted.append({ + "type": det["class_name"], + "size": [int(round(x2 - x1)), int(round(y2 - y1))], + "leftTopPoint": [int(round(x1)), int(round(y1))], + "score": round(det["conf"], 4), + }) + return formatted + + def get_detection_stats(self): + """获取检测统计信息""" + return dict(self.stats) + + def detect(self, img_path, target_types=None): + """完整检测流程""" + # 重置统计 + self.stats = defaultdict(int) + + # 设置目标类别 + if target_types: + self.target_ids = [self.type_to_id[cls] for cls in target_types if cls in self.type_to_id] + else: + self.target_ids = None + + # 执行检测 + primary_dets = self.primary_detect(img_path) + print(f" 初级检测后: {self.stats['primary']} 个目标") + + if self.params["enable_secondary"]: + secondary_dets = self.secondary_detect(img_path) + print(f" 次级检测后: {self.stats['secondary']} 个目标") + merged_dets = self.merge_detections(primary_dets, secondary_dets) + print(f" 融合去重后: {self.stats['merged']} 个目标") + else: + merged_dets = primary_dets + print(f" 次级检测未启用") + + # 后处理 + processed_dets = self.post_process(merged_dets) + print(f" 过滤低置信度后: {self.stats['final']} 个目标") + + print(" 最终检测目标详情:") + for idx, det in enumerate(processed_dets, 1): + print(f" 目标{idx} - 类型:{det['class_name']},置信度:{det['conf']:.4f}") + + return self.format_results(processed_dets) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..f137d6c --- /dev/null +++ b/main.py @@ -0,0 +1,125 @@ +import io +import os +import tempfile +from contextlib import asynccontextmanager + +import requests +import uvicorn +from PIL import Image +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, HttpUrl + +from config import DetectionResponse +from process import detect_large_image_from_url + +# 全局检测管理器 +detector_manager = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global detector_manager + try: + from manager import UnifiedDetectionManager + detector_manager = UnifiedDetectionManager() + print("检测管理器初始化成功") + except Exception as e: + print(f"初始化失败:{str(e)}") + raise + yield + + +app = FastAPI(lifespan=lifespan, title="目标检测API", version="1.0.0") + +# 配置跨域请求 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class DetectionRequest(BaseModel): + type: str + url: HttpUrl + + +class DetectionProcessRequest(BaseModel): + url: HttpUrl + + +@app.post("/detect_image", response_model=DetectionResponse) +async def run_detection_image(request: DetectionRequest): + # 解析检测类型 + requested_types = {t.strip().lower() for t in request.type.split(',') if t.strip()} + print(f"请求的检测类型: {requested_types}") + if not requested_types: + raise HTTPException(status_code=400, detail="未指定检测类型") + + # 下载图片 + try: + response = requests.get(str(request.url), timeout=15) + response.raise_for_status() + + # 获取图片尺寸 + with Image.open(io.BytesIO(response.content)) as img: + img_size = [img.width, img.height] + + # 创建临时文件 + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file: + temp_file.write(response.content) + temp_path = temp_file.name + + except Exception as e: + raise HTTPException(status_code=400, detail=f"图片处理失败:{str(e)}") + + # 执行检测 + results = [] + errors = [] + + try: + detection_results = detector_manager.detect(temp_path, ",".join(requested_types)) + if detection_results: + results = detection_results + except Exception as e: + errors.append(f"检测失败:{str(e)}") + finally: + # 清理临时文件 + if os.path.exists(temp_path): + os.remove(temp_path) + + return { + "hasTarget": 1 if results else 0, + "originalImgSize": img_size, + "targets": results, + "processing_errors": errors + } + + +@app.post("/detect_process", response_model=DetectionResponse) +async def run_detection_process(request: DetectionProcessRequest): + return detect_large_image_from_url(str(request.url)) + + +@app.get("/supported_types") +async def get_supported_types(): + if detector_manager: + info = detector_manager.get_available_info() + return { + "supported_types": info["supported_types"], + } + return {"supported_types": []} + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--reload", action="store_true") + args = parser.parse_args() + uvicorn.run("main:app", host=args.host, port=args.port, reload=args.reload) diff --git a/manager.py b/manager.py new file mode 100644 index 0000000..fb76b27 --- /dev/null +++ b/manager.py @@ -0,0 +1,116 @@ +import os +from collections import defaultdict + +from config import MODEL_CONFIGS +from detect import YOLODetector + + +class UnifiedDetectionManager: + """统一检测管理器""" + + def __init__(self): + self.detectors = {} # 检测器实例 + self.type_to_model = {} # 类别到模型映射 + self.loaded_models = [] # 已加载模型 + self.type_to_id = {} # 全局类别ID映射 + + self._load_models() + + def _load_models(self): + """加载所有模型""" + if not MODEL_CONFIGS: + raise ValueError("模型配置为空") + + for model_name, config in MODEL_CONFIGS.items(): + try: + model_path = config["model_path"] + if not os.path.exists(model_path): + print(f"跳过 {model_name}: 模型文件不存在 - {model_path}") + continue + + # 创建检测器(自动传递新增的enable_primary配置) + detector = YOLODetector( + model_path=model_path, + params=config["params"], + type_to_id=config["type_to_id"] + ) + + # 保存状态 + self.detectors[model_name] = detector + self.loaded_models.append(model_name) + + # 建立映射 + for det_type in config["types"]: + det_type_lower = det_type.lower() + if det_type_lower in self.type_to_model: + print(f"警告: 类别 '{det_type}' 映射冲突") + self.type_to_model[det_type_lower] = model_name + self.type_to_id[det_type_lower] = config["type_to_id"][det_type_lower] + + print(f"加载成功: {model_name}") + + except Exception as e: + print(f"加载失败 {model_name}: {str(e)}") + continue + + print(f"模型加载完成: {len(self.loaded_models)}/{len(MODEL_CONFIGS)}") + print(f"支持类别: {list(self.type_to_model.keys())}") + + def parse_types(self, types_str): + """解析检测类型""" + if not types_str: + raise ValueError("检测类型为空") + + # 清理输入 + requested_types = list(set(t.strip().lower() for t in types_str.split(',') if t.strip())) + + # 按模型分组 + model_type_map = defaultdict(list) + for det_type in requested_types: + if det_type in self.type_to_model: + model_name = self.type_to_model[det_type] + model_type_map[model_name].append(det_type) + else: + print(f"忽略未知类别: {det_type}") + + if not model_type_map: + raise ValueError("无有效检测类别") + + return model_type_map + + def detect(self, img_path, detection_types): + """执行检测""" + if not os.path.exists(img_path): + raise FileNotFoundError(f"图像不存在: {img_path}") + + # 解析类型 + model_type_map = self.parse_types(detection_types) + + # 执行检测(自动适配enable_primary配置) + all_results = [] + for model_name, target_types in model_type_map.items(): + if model_name not in self.detectors: + continue + + print(f"检测: {model_name} -> {target_types}") + try: + results = self.detectors[model_name].detect(img_path, target_types) + all_results.extend(results) + + # 获取详细统计信息 + stats = self.detectors[model_name].get_detection_stats() + print(f" {model_name}详细统计: {stats}") + + except Exception as e: + print(f"检测失败 {model_name}: {str(e)}") + + print(f"检测完成: 总共 {len(all_results)} 个结果") + return all_results + + def get_available_info(self): + """获取可用信息""" + return { + "loaded_models": self.loaded_models, + "supported_types": list(self.type_to_model.keys()), + "type_to_model": self.type_to_model + } \ No newline at end of file diff --git a/models/fire_smoke_model/best.pt b/models/fire_smoke_model/best.pt new file mode 100644 index 0000000..d6e3f31 Binary files /dev/null and b/models/fire_smoke_model/best.pt differ diff --git a/models/ppe_state_model/best.pt b/models/ppe_state_model/best.pt new file mode 100644 index 0000000..65c2302 Binary files /dev/null and b/models/ppe_state_model/best.pt differ diff --git a/models/solor_bracket_model/best.pt b/models/solor_bracket_model/best.pt new file mode 100644 index 0000000..eb3b291 Binary files /dev/null and b/models/solor_bracket_model/best.pt differ diff --git a/process.py b/process.py new file mode 100644 index 0000000..1383e3c --- /dev/null +++ b/process.py @@ -0,0 +1,180 @@ +import os +import tempfile +from typing import List +from urllib.parse import urlparse + +import cv2 +import requests +from pydantic import BaseModel +from tqdm import tqdm +from ultralytics import YOLO + + +# 定义返回值模型 +class DetectionResponse(BaseModel): + hasTarget: int + originalImgSize: List[int] + targets: List[dict] + processing_errors: List[str] = [] + + +def download_large_file(url, chunk_size=1024 * 1024): + """下载大型文件到临时文件、返回临时文件路径""" + try: + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() + + file_size = int(response.headers.get('Content-Length', 0)) + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') # 适配png格式 + temp_file_path = temp_file.name + temp_file.close() + + with open(temp_file_path, 'wb') as f, tqdm( + total=file_size, unit='B', unit_scale=True, + desc=f"下载 {os.path.basename(urlparse(url).path)}" + ) as pbar: + for chunk in response.iter_content(chunk_size=chunk_size): + if chunk: + f.write(chunk) + pbar.update(len(chunk)) + + return temp_file_path + + except Exception as e: + error_msg = f"下载失败: {str(e)}" + print(error_msg) + if 'temp_file_path' in locals(): + try: + os.remove(temp_file_path) + except: + pass + raise Exception(error_msg) + + +def slice_large_image(image_path, slice_size=1024, overlap=100): + """切分大图为切片、返回切片数据和位置信息""" + img = cv2.imread(image_path) + if img is None: + raise ValueError(f"无法读取图像: {image_path}") + + h, w = img.shape[:2] + step = slice_size - overlap + num_rows = (h + step - 1) // step + num_cols = (w + step - 1) // step + + slices = [] + for i in range(num_rows): + for j in range(num_cols): + y1 = i * step + x1 = j * step + y2 = min(y1 + slice_size, h) + x2 = min(x1 + slice_size, w) + + if y2 - y1 < slice_size: + y1 = max(0, y2 - slice_size) + if x2 - x1 < slice_size: + x1 = max(0, x2 - slice_size) + + slice_img = img[y1:y2, x1:x2] + slices.append((x1, y1, slice_img)) + + return slices, (h, w) + + +def extract_detection_info(result, slice_offset_x, slice_offset_y): + """从YOLO OBB结果中提取检测框信息(修正宽高计算)""" + detections = [] + + if result.obb is not None and len(result.obb) > 0: + obb_data = result.obb + obb_xyxy = obb_data.xyxy.cpu().numpy() + classes = obb_data.cls.cpu().numpy() + confidences = obb_data.conf.cpu().numpy() + + for i in range(len(obb_data)): + x1_slice, y1_slice, x2_slice, y2_slice = obb_xyxy[i] + # 计算实际宽高(x方向为宽,y方向为高) + width = x2_slice - x1_slice + height = y2_slice - y1_slice + + # 转换为全局坐标 + x1_global = x1_slice + slice_offset_x + y1_global = y1_slice + slice_offset_y + + cls_id = int(classes[i]) + confidence = float(confidences[i]) + class_name = result.names[cls_id] + + detection_info = { + "type": class_name, + "size": [int(round(width)), int(round(height))], + "leftTopPoint": [int(round(x1_global)), int(round(y1_global))], + "score": round(confidence, 4) + } + detections.append(detection_info) + + return detections + + +def detect_large_image_from_url(image_url: str, slice_size: int = 1024, overlap: int = 100) -> DetectionResponse: + """ + 封装后的检测方法:从图片URL处理大图、返回DetectionResponse对象 + """ + # 动态拼接固定model_path(当前文件同级目录下) + current_dir = os.path.dirname(os.path.abspath(__file__)) + model_path = os.path.join(current_dir, "models", "solor_bracket_model", "best.pt") + + processing_errors = [] + all_detections = [] + original_size = [0, 0] + + try: + # 验证模型文件是否存在 + if not os.path.exists(model_path): + raise FileNotFoundError(f"模型文件不存在:{model_path}") + + # 下载图像 + temp_file_path = download_large_file(image_url) + + try: + # 切分图像 + slices_info, (h, w) = slice_large_image(temp_file_path, slice_size, overlap) + original_size = [w, h] + print(f"完成切片: 共 {len(slices_info)} 个切片") + + # 加载模型并预测 + model = YOLO(model_path) + print("开始逐张预测切片...") + + for i, (x1, y1, slice_img) in enumerate(slices_info, 1): + print(f"预测第 {i}/{len(slices_info)} 个切片") + result = model(slice_img, conf=0.5, verbose=False)[0] + slice_detections = extract_detection_info(result, x1, y1) + all_detections.extend(slice_detections) + print(f" 本切片检测到 {len(slice_detections)} 个目标") + + finally: + # 确保临时文件删除 + if os.path.exists(temp_file_path): + try: + os.remove(temp_file_path) + print("临时文件已删除") + except Exception as e: + error_msg = f"删除临时文件失败: {str(e)}" + print(error_msg) + processing_errors.append(error_msg) + + except Exception as e: + # 捕获所有异常并记录 + error_msg = str(e) + processing_errors.append(error_msg) + print(f"处理异常: {error_msg}") + + # 构建并返回DetectionResponse对象 + return DetectionResponse( + hasTarget=1 if len(all_detections) > 0 else 0, + originalImgSize=original_size, + targets=all_detections, + processing_errors=processing_errors + )