智能体加检测
This commit is contained in:
125
main.py
Normal file
125
main.py
Normal file
@ -0,0 +1,125 @@
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import requests
|
||||
import uvicorn
|
||||
from PIL import Image
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from config import DetectionResponse
|
||||
from process import detect_large_image_from_url
|
||||
|
||||
# 全局检测管理器
|
||||
detector_manager = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global detector_manager
|
||||
try:
|
||||
from manager import UnifiedDetectionManager
|
||||
detector_manager = UnifiedDetectionManager()
|
||||
print("检测管理器初始化成功")
|
||||
except Exception as e:
|
||||
print(f"初始化失败:{str(e)}")
|
||||
raise
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan, title="目标检测API", version="1.0.0")
|
||||
|
||||
# 配置跨域请求
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class DetectionRequest(BaseModel):
|
||||
type: str
|
||||
url: HttpUrl
|
||||
|
||||
|
||||
class DetectionProcessRequest(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
|
||||
@app.post("/detect_image", response_model=DetectionResponse)
|
||||
async def run_detection_image(request: DetectionRequest):
|
||||
# 解析检测类型
|
||||
requested_types = {t.strip().lower() for t in request.type.split(',') if t.strip()}
|
||||
print(f"请求的检测类型: {requested_types}")
|
||||
if not requested_types:
|
||||
raise HTTPException(status_code=400, detail="未指定检测类型")
|
||||
|
||||
# 下载图片
|
||||
try:
|
||||
response = requests.get(str(request.url), timeout=15)
|
||||
response.raise_for_status()
|
||||
|
||||
# 获取图片尺寸
|
||||
with Image.open(io.BytesIO(response.content)) as img:
|
||||
img_size = [img.width, img.height]
|
||||
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
|
||||
temp_file.write(response.content)
|
||||
temp_path = temp_file.name
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"图片处理失败:{str(e)}")
|
||||
|
||||
# 执行检测
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
try:
|
||||
detection_results = detector_manager.detect(temp_path, ",".join(requested_types))
|
||||
if detection_results:
|
||||
results = detection_results
|
||||
except Exception as e:
|
||||
errors.append(f"检测失败:{str(e)}")
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
return {
|
||||
"hasTarget": 1 if results else 0,
|
||||
"originalImgSize": img_size,
|
||||
"targets": results,
|
||||
"processing_errors": errors
|
||||
}
|
||||
|
||||
|
||||
@app.post("/detect_process", response_model=DetectionResponse)
|
||||
async def run_detection_process(request: DetectionProcessRequest):
|
||||
return detect_large_image_from_url(str(request.url))
|
||||
|
||||
|
||||
@app.get("/supported_types")
|
||||
async def get_supported_types():
|
||||
if detector_manager:
|
||||
info = detector_manager.get_available_info()
|
||||
return {
|
||||
"supported_types": info["supported_types"],
|
||||
}
|
||||
return {"supported_types": []}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--reload", action="store_true")
|
||||
args = parser.parse_args()
|
||||
uvicorn.run("main:app", host=args.host, port=args.port, reload=args.reload)
|
||||
Reference in New Issue
Block a user