126 lines
3.4 KiB
Python
126 lines
3.4 KiB
Python
|
|
import io
|
||
|
|
import os
|
||
|
|
import tempfile
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
|
||
|
|
import requests
|
||
|
|
import uvicorn
|
||
|
|
from PIL import Image
|
||
|
|
from fastapi import FastAPI, HTTPException
|
||
|
|
from fastapi.middleware.cors import CORSMiddleware
|
||
|
|
from pydantic import BaseModel, HttpUrl
|
||
|
|
|
||
|
|
from config import DetectionResponse
|
||
|
|
from process import detect_large_image_from_url
|
||
|
|
|
||
|
|
# Global detection manager. `lifespan` assigns a UnifiedDetectionManager here
# at application startup; it stays None until then.
detector_manager = None
|
||
|
|
|
||
|
|
|
||
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared detection manager before the app starts serving.

    Runs once at startup. It imports ``UnifiedDetectionManager`` from the
    local ``manager`` module, stores an instance in the module-level
    ``detector_manager`` and prints a success message. Any failure, including
    the import itself, is printed and re-raised, so startup aborts. Nothing
    is torn down after the ``yield``.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    global detector_manager

    try:
        # The import sits inside the try, so an import failure is also
        # reported by the handler below.
        from manager import UnifiedDetectionManager

        detector_manager = UnifiedDetectionManager()
        print("检测管理器初始化成功")
    except Exception as exc:
        print(f"初始化失败:{exc!s}")
        raise

    yield
|
||
|
|
|
||
|
|
|
||
|
|
app = FastAPI(title="目标检测API", version="1.0.0", lifespan=lifespan)

# CORS: accept cross-origin requests from any origin, with any method and any
# header, and allow credentials.
# NOTE(review): the CORS spec does not allow a literal "*" origin on
# credentialed requests, and Starlette's docs ask for explicit origin, method
# and header lists when allow_credentials=True. Confirm whether credentials
# are actually needed; if not, set allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
|
||
|
|
|
||
|
|
|
||
|
|
class DetectionRequest(BaseModel):
    """Request body for ``POST /detect_image``."""

    # Comma-separated detection type names. The endpoint splits, trims and
    # lower-cases them, and rejects a value that has no non-blank entries.
    type: str
    # Image to download; pydantic's HttpUrl validates it as an HTTP(S) URL.
    url: HttpUrl
|
||
|
|
|
||
|
|
|
||
|
|
class DetectionProcessRequest(BaseModel):
    """Request body for ``POST /detect_process``."""

    # Image to process. Pydantic's HttpUrl validates it as an HTTP(S) URL;
    # the endpoint passes it on as a plain string.
    url: HttpUrl
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_detection_types(raw_types):
    """Split a comma-separated type list into unique, lower-cased names.

    Blank entries are dropped. Duplicates are removed while keeping the order
    in which each type first appears, so the string handed to the detector is
    deterministic (joining a plain ``set`` gives an arbitrary order).
    """
    return list(dict.fromkeys(t.strip().lower() for t in raw_types.split(",") if t.strip()))


def _remove_file_quietly(path):
    """Delete ``path``; a file that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _download_image_to_tempfile(url):
    """Fetch ``url``, read the image size and save the bytes to a temp file.

    This is blocking (network and disk), so run it off the event loop.

    Returns:
        ``(img_size, temp_path)``. ``img_size`` is ``[width, height]`` as
        reported by PIL. ``temp_path`` is a ``.jpg``-suffixed file that the
        caller must delete.

    Raises:
        Whatever ``requests``, PIL or the filesystem raise. If writing the temp
        file fails, that file is removed before the error propagates.
    """
    # NOTE(review): the URL is client-supplied and fetched server-side with no
    # host allow-list or size limit. Consider restricting it (SSRF / memory use).
    response = requests.get(url, timeout=15)
    response.raise_for_status()
    content = response.content

    # PIL's Image.open is lazy: identifying the format and reading the size
    # does not decode the pixel data. Non-image payloads raise here.
    with Image.open(io.BytesIO(content)) as img:
        img_size = [img.width, img.height]

    temp_file = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    try:
        with temp_file:
            temp_file.write(content)
    except Exception:
        # With delete=False nothing else cleans this file up. It is closed by
        # now, so removal also works on Windows.
        _remove_file_quietly(temp_file.name)
        raise
    return img_size, temp_file.name


@app.post("/detect_image", response_model=DetectionResponse)
async def run_detection_image(request: DetectionRequest):
    """Download the image at ``request.url`` and run the requested detectors on it.

    ``request.type`` is a comma-separated list of detection types
    (case-insensitive; blank entries are ignored).

    Returns:
        A dict for ``DetectionResponse`` with these keys:
        - ``hasTarget``: 1 if the detector returned any results, else 0.
        - ``originalImgSize``: ``[width, height]``.
        - ``targets``: the detector's results, or ``[]``.
        - ``processing_errors``: messages from a failed detection run.

    Raises:
        HTTPException(400): no detection type was given, or the image could
        not be downloaded, opened or saved.
    """
    # FastAPI re-exports Starlette's run_in_threadpool.
    from fastapi.concurrency import run_in_threadpool

    requested_types = _parse_detection_types(request.type)
    print(f"请求的检测类型: {requested_types}")
    if not requested_types:
        raise HTTPException(status_code=400, detail="未指定检测类型")

    # The download is synchronous (up to a 15 s timeout). Running it in the
    # worker pool keeps this async endpoint from stalling the event loop.
    try:
        img_size, temp_path = await run_in_threadpool(
            _download_image_to_tempfile, str(request.url)
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片处理失败:{str(e)}") from e

    results = []
    errors = []
    try:
        # NOTE(review): detection still runs on the event loop. That keeps
        # calls to the shared manager serialized, as before. Move it to the
        # threadpool only once the manager is confirmed to be thread-safe.
        detection_results = detector_manager.detect(temp_path, ",".join(requested_types))
        if detection_results:
            results = detection_results
    except Exception as e:
        # Detection failures are reported in the response body, not as an HTTP error.
        errors.append(f"检测失败:{str(e)}")
    finally:
        _remove_file_quietly(temp_path)

    return {
        "hasTarget": 1 if results else 0,
        "originalImgSize": img_size,
        "targets": results,
        "processing_errors": errors,
    }
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/detect_process", response_model=DetectionResponse)
async def run_detection_process(request: DetectionProcessRequest):
    """Delegate to ``process.detect_large_image_from_url`` for the given URL.

    The validated ``HttpUrl`` is converted to a plain string first. The
    helper's return value is passed back unchanged, and FastAPI validates it
    against ``DetectionResponse``.
    """
    # NOTE(review): this synchronous call runs on the event loop inside an
    # ``async def`` endpoint, so it blocks other requests while it runs.
    # Consider run_in_threadpool once the helper is confirmed thread-safe.
    image_url = str(request.url)
    return detect_large_image_from_url(image_url)
|
||
|
|
|
||
|
|
|
||
|
|
@app.get("/supported_types")
async def get_supported_types():
    """List the detection types the loaded manager can handle.

    Returns:
        ``{"supported_types": ...}``. The value comes from
        ``detector_manager.get_available_info()["supported_types"]``, or is an
        empty list while ``detector_manager`` is unset (falsy).
    """
    if not detector_manager:
        return {"supported_types": []}

    available = detector_manager.get_available_info()
    return {"supported_types": available["supported_types"]}
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the object-detection API with uvicorn.")
    parser.add_argument("--host", default="0.0.0.0", help="interface to bind (default: 0.0.0.0)")
    parser.add_argument("--port", type=int, default=8000, help="port to listen on (default: 8000)")
    parser.add_argument("--reload", action="store_true", help="restart the server when source files change")
    args = parser.parse_args()

    # uvicorn needs an import string, not the app object, for --reload to
    # work. Derive the module name from this file instead of hard-coding
    # "main", so the script still works if the file is renamed. The result is
    # identical when the file is main.py.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host=args.host, port=args.port, reload=args.reload)
|