Files
AI_agent_detect/main.py

322 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import os
import time
import tempfile
import shutil
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
import requests
import uvicorn
from PIL import Image
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, HttpUrl
# 只导入AI_Agent的统一入口方法不再导入多个内部方法
from AI_Agent import generate_tender_from_input
from config import DetectionResponse
from process import detect_large_image_from_url
# 配置文件存储路径(使用绝对路径确保一致性)
BASE_DIR = Path(__file__).parent.resolve() # 项目根目录绝对路径
UPLOAD_DIR = BASE_DIR / "uploaded_files"
OUTPUT_DIR = BASE_DIR / "generated_tenders"
UPLOAD_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)
# 全局检测管理器
detector_manager = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global detector_manager
try:
from manager import UnifiedDetectionManager
detector_manager = UnifiedDetectionManager()
print("检测管理器初始化成功")
except Exception as e:
print(f"初始化失败:{str(e)}")
raise
yield
# 程序关闭时清理临时文件(可选)
print("清理临时文件...")
for file in UPLOAD_DIR.glob("*"):
try:
if file.is_file():
file.unlink()
except Exception as e:
print(f"清理文件 {file} 失败:{e}")
app = FastAPI(lifespan=lifespan, title="目标检测与投标文件生成API", version="1.0.0")
# 配置跨域请求
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class DetectionRequest(BaseModel):
type: str
url: HttpUrl
class DetectionProcessRequest(BaseModel):
url: HttpUrl
class TenderGenerateResponse(BaseModel):
"""投标文件生成响应模型"""
status: str
message: str
relative_path: str
file_name: str
file_size: Optional[int] = None
@app.post("/detect_image", response_model=DetectionResponse)
async def run_detection_image(request: DetectionRequest):
# 解析检测类型
requested_types = {t.strip().lower() for t in request.type.split(',') if t.strip()}
print(f"请求的检测类型: {requested_types}")
if not requested_types:
raise HTTPException(status_code=400, detail="未指定检测类型")
# 下载图片
try:
response = requests.get(str(request.url), timeout=15)
response.raise_for_status()
# 获取图片尺寸
with Image.open(io.BytesIO(response.content)) as img:
img_size = [img.width, img.height]
# 创建临时文件
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
temp_file.write(response.content)
temp_path = temp_file.name
except Exception as e:
raise HTTPException(status_code=400, detail=f"图片处理失败:{str(e)}")
# 执行检测
results = []
errors = []
try:
detection_results = detector_manager.detect(temp_path, ",".join(requested_types))
if detection_results:
results = detection_results
except Exception as e:
errors.append(f"检测失败:{str(e)}")
finally:
# 清理临时文件
if os.path.exists(temp_path):
os.remove(temp_path)
return {
"hasTarget": 1 if results else 0,
"originalImgSize": img_size,
"targets": results,
"processing_errors": errors
}
@app.post("/detect_process", response_model=DetectionResponse)
async def run_detection_process(request: DetectionProcessRequest):
return detect_large_image_from_url(str(request.url))
@app.get("/supported_types")
async def get_supported_types():
if detector_manager:
info = detector_manager.get_available_info()
return {
"supported_types": info["supported_types"],
}
return {"supported_types": []}
@app.post("/generate_tender", response_model=TenderGenerateResponse)
async def generate_tender_file(file: UploadFile = File(...)):
"""
上传招标文件Word格式生成投标文件
支持的文件格式:.docx
返回生成文件的相对路径,用于下载接口
"""
# 初始化upload_path变量避免UnboundLocalError
upload_path = None
# 验证文件类型
if not file.filename.endswith((".docx", ".doc")):
raise HTTPException(
status_code=400,
detail="仅支持Word文件.docx或.doc格式"
)
try:
# 保存上传的文件
timestamp = int(time.time())
file_ext = Path(file.filename).suffix
upload_filename = f"tender_{timestamp}{file_ext}"
upload_path = UPLOAD_DIR / upload_filename
# 保存文件内容
with open(upload_path, "wb") as f:
shutil.copyfileobj(file.file, f)
print(f"已接收上传文件:{upload_path}")
# 生成输出文件名和路径(使用绝对路径)
output_filename = f"投标文件_生成版_{timestamp}.docx"
output_path = OUTPUT_DIR / output_filename
output_path_abs = str(output_path.resolve()) # 绝对路径
# 调用AI_Agent的统一入口方法仅这一个调用
print("开始生成投标文件...")
generate_success = generate_tender_from_input(
input_word_path=str(upload_path.resolve()), # 传入绝对路径
output_word_path=output_path_abs
)
if not generate_success:
raise Exception("投标文件生成核心流程执行失败")
# 验证生成的文件是否存在
if not output_path.exists() or not output_path.is_file():
raise Exception("投标文件生成后未找到目标文件")
# 计算文件大小
file_size = output_path.stat().st_size
# 修复使用os.path.relpath计算相对路径更灵活
try:
# 计算相对于项目根目录的相对路径
relative_path = os.path.relpath(output_path_abs, str(BASE_DIR))
# 统一路径分隔符为 '/'避免Windows和Linux差异
relative_path = relative_path.replace(os.sep, '/')
except Exception as e:
# 异常情况下直接返回文件名(降级方案)
relative_path = output_filename
print(f"计算相对路径失败,使用降级方案:{e}")
print(f"投标文件生成成功:{output_path_abs}")
print(f"相对路径:{relative_path}")
return {
"status": "success",
"message": "投标文件生成完成",
"relative_path": relative_path,
"file_name": output_filename,
"file_size": file_size
}
except Exception as e:
error_msg = f"生成投标文件失败:{str(e)}"
print(error_msg)
raise HTTPException(
status_code=500,
detail=error_msg
)
finally:
# 关闭上传文件流
await file.close()
# 清理上传的原始文件,节省空间
if upload_path and upload_path.exists():
try:
upload_path.unlink()
print(f"已清理上传文件:{upload_path}")
except Exception as e:
print(f"清理上传文件失败:{e}")
@app.get("/download_file")
async def download_generated_file(relative_path: str):
"""
根据相对路径下载生成的投标文件
Args:
relative_path: 生成文件的相对路径(从/generate_tender接口获取
"""
try:
# 修复:将相对路径转换为绝对路径
# 统一路径分隔符
relative_path = relative_path.replace('/', os.sep)
# 拼接项目根目录得到绝对路径
abs_path = BASE_DIR / relative_path
abs_path = abs_path.resolve() # 解析完整路径
# 验证路径是否在允许的输出目录内
if not abs_path.is_relative_to(OUTPUT_DIR.resolve()):
raise HTTPException(
status_code=403,
detail="访问禁止:文件路径不在允许范围内"
)
if not abs_path.exists() or not abs_path.is_file():
raise HTTPException(
status_code=404,
detail="文件不存在或已被删除"
)
# 返回文件下载响应
return FileResponse(
path=abs_path,
filename=abs_path.name,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
except HTTPException:
raise
except Exception as e:
print(f"文件下载失败:{str(e)}")
raise HTTPException(
status_code=500,
detail=f"文件下载失败:{str(e)}"
)
@app.get("/list_generated_files")
async def list_generated_files():
"""列出所有生成的投标文件(可选接口)"""
files = []
for file in OUTPUT_DIR.glob("*.docx"):
# 计算相对路径
try:
relative_path = os.path.relpath(str(file.resolve()), str(BASE_DIR))
relative_path = relative_path.replace(os.sep, '/')
except:
relative_path = file.name
files.append({
"file_name": file.name,
"relative_path": relative_path,
"file_size": file.stat().st_size,
"created_time": file.stat().st_ctime
})
return {
"total": len(files),
"files": files
}
if __name__ == "__main__":
import argparse
# 补充必要的导入(在主函数中导入,避免启动时依赖冲突)
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--reload", action="store_true")
args = parser.parse_args()
uvicorn.run("main:app", host=args.host, port=args.port, reload=args.reload)