添加ai智能体标书接口,修改ai_agent仅暴露一个方法供外部调用
This commit is contained in:
177
main.py
177
main.py
@ -1,18 +1,30 @@
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import tempfile
|
||||
import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
import uvicorn
|
||||
from PIL import Image
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from AI_Agent import generate_tender_from_input
|
||||
from config import DetectionResponse
|
||||
from process import detect_large_image_from_url
|
||||
|
||||
# 配置文件存储路径
|
||||
UPLOAD_DIR = Path("uploaded_files")
|
||||
OUTPUT_DIR = Path("generated_tenders")
|
||||
UPLOAD_DIR.mkdir(exist_ok=True)
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# 全局检测管理器
|
||||
detector_manager = None
|
||||
|
||||
@ -28,9 +40,17 @@ async def lifespan(app: FastAPI):
|
||||
print(f"初始化失败:{str(e)}")
|
||||
raise
|
||||
yield
|
||||
# 程序关闭时清理临时文件(可选)
|
||||
print("清理临时文件...")
|
||||
for file in UPLOAD_DIR.glob("*"):
|
||||
try:
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
except Exception as e:
|
||||
print(f"清理文件 {file} 失败:{e}")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan, title="目标检测API", version="1.0.0")
|
||||
app = FastAPI(lifespan=lifespan, title="目标检测与投标文件生成API", version="1.0.0")
|
||||
|
||||
# 配置跨域请求
|
||||
app.add_middleware(
|
||||
@ -51,6 +71,15 @@ class DetectionProcessRequest(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
|
||||
class TenderGenerateResponse(BaseModel):
|
||||
"""投标文件生成响应模型"""
|
||||
status: str
|
||||
message: str
|
||||
relative_path: str
|
||||
file_name: str
|
||||
file_size: Optional[int] = None
|
||||
|
||||
|
||||
@app.post("/detect_image", response_model=DetectionResponse)
|
||||
async def run_detection_image(request: DetectionRequest):
|
||||
# 解析检测类型
|
||||
@ -114,12 +143,154 @@ async def get_supported_types():
|
||||
return {"supported_types": []}
|
||||
|
||||
|
||||
@app.post("/generate_tender", response_model=TenderGenerateResponse)
|
||||
async def generate_tender_file(file: UploadFile = File(...)):
|
||||
"""
|
||||
上传招标文件(Word格式),生成投标文件
|
||||
支持的文件格式:.docx
|
||||
返回生成文件的相对路径,用于下载接口
|
||||
"""
|
||||
upload_path = None
|
||||
|
||||
# 验证文件类型
|
||||
if not file.filename.endswith((".docx", ".doc")):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="仅支持Word文件(.docx或.doc格式)"
|
||||
)
|
||||
|
||||
try:
|
||||
# 保存上传的文件
|
||||
timestamp = int(time.time())
|
||||
file_ext = Path(file.filename).suffix
|
||||
upload_filename = f"tender_{timestamp}{file_ext}"
|
||||
upload_path = UPLOAD_DIR / upload_filename
|
||||
|
||||
# 保存文件内容
|
||||
with open(upload_path, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
|
||||
print(f"已接收上传文件:{upload_path}")
|
||||
|
||||
# 生成输出文件名和路径
|
||||
output_filename = f"投标文件_生成版_{timestamp}.docx"
|
||||
output_path = OUTPUT_DIR / output_filename
|
||||
|
||||
# 调用AI_Agent的统一入口方法(仅这一个调用)
|
||||
print("开始生成投标文件...")
|
||||
generate_success = generate_tender_from_input(
|
||||
input_word_path=str(upload_path),
|
||||
output_word_path=str(output_path)
|
||||
)
|
||||
|
||||
if not generate_success:
|
||||
raise Exception("投标文件生成核心流程执行失败")
|
||||
|
||||
# 验证生成的文件是否存在
|
||||
if not output_path.exists() or not output_path.is_file():
|
||||
raise Exception("投标文件生成后未找到目标文件")
|
||||
|
||||
# 计算文件大小
|
||||
file_size = output_path.stat().st_size
|
||||
|
||||
# 构建相对路径(相对于项目根目录)
|
||||
relative_path = str(output_path.relative_to(Path.cwd()))
|
||||
|
||||
print(f"投标文件生成成功:{output_path}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "投标文件生成完成",
|
||||
"relative_path": relative_path,
|
||||
"file_name": output_filename,
|
||||
"file_size": file_size
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成投标文件失败:{str(e)}"
|
||||
print(error_msg)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=error_msg
|
||||
)
|
||||
finally:
|
||||
# 关闭上传文件流
|
||||
await file.close()
|
||||
# 清理上传的原始文件,节省空间
|
||||
if upload_path and upload_path.exists():
|
||||
try:
|
||||
upload_path.unlink()
|
||||
print(f"已清理上传文件:{upload_path}")
|
||||
except Exception as e:
|
||||
print(f"清理上传文件失败:{e}")
|
||||
|
||||
|
||||
@app.get("/download_file")
|
||||
async def download_generated_file(relative_path: str):
|
||||
"""
|
||||
根据相对路径下载生成的投标文件
|
||||
Args:
|
||||
relative_path: 生成文件的相对路径(从/generate_tender接口获取)
|
||||
"""
|
||||
# 解析绝对路径,防止路径穿越攻击
|
||||
try:
|
||||
abs_path = Path(relative_path).resolve()
|
||||
|
||||
# 验证路径是否在允许的输出目录内
|
||||
if not abs_path.is_relative_to(OUTPUT_DIR.resolve()):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="访问禁止:文件路径不在允许范围内"
|
||||
)
|
||||
|
||||
if not abs_path.exists() or not abs_path.is_file():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="文件不存在或已被删除"
|
||||
)
|
||||
|
||||
# 返回文件下载响应
|
||||
return FileResponse(
|
||||
path=abs_path,
|
||||
filename=abs_path.name,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"文件下载失败:{str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"文件下载失败:{str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/list_generated_files")
|
||||
async def list_generated_files():
|
||||
"""列出所有生成的投标文件(可选接口)"""
|
||||
files = []
|
||||
for file in OUTPUT_DIR.glob("*.docx"):
|
||||
files.append({
|
||||
"file_name": file.name,
|
||||
"relative_path": str(file.relative_to(Path.cwd())),
|
||||
"file_size": file.stat().st_size,
|
||||
"created_time": file.stat().st_ctime
|
||||
})
|
||||
return {
|
||||
"total": len(files),
|
||||
"files": files
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
# 补充必要的导入(在主函数中导入,避免启动时依赖冲突)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--reload", action="store_true")
|
||||
args = parser.parse_args()
|
||||
uvicorn.run("main:app", host=args.host, port=args.port, reload=args.reload)
|
||||
|
||||
uvicorn.run("main:app", host=args.host, port=args.port, reload=args.reload)
|
||||
Reference in New Issue
Block a user