127 lines
3.4 KiB
Python
127 lines
3.4 KiB
Python
import io
|
||
import os
|
||
import tempfile
|
||
from contextlib import asynccontextmanager
|
||
|
||
import requests
|
||
import uvicorn
|
||
from PIL import Image
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel, HttpUrl
|
||
|
||
from config import DetectionResponse
|
||
from process import detect_large_image_from_url
|
||
|
||
# Process-wide detection manager shared by the request handlers. It starts out
# as None and is replaced with a UnifiedDetectionManager instance by lifespan().
detector_manager = None
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared detection manager when the application starts.

    ``manager.UnifiedDetectionManager`` is imported lazily inside the ``try``,
    so an import error and a construction error are handled the same way:
    a failure message is printed and the exception is re-raised, which makes
    startup fail. Nothing is cleaned up after the ``yield``.
    """
    global detector_manager

    try:
        from manager import UnifiedDetectionManager

        detector_manager = UnifiedDetectionManager()
        print("检测管理器初始化成功")
    except Exception as exc:
        # Report, then propagate: the service must not run without a detector.
        print(f"初始化失败:{exc}")
        raise

    yield
|
||
|
||
|
||
app = FastAPI(title="目标检测API", version="1.0.0", lifespan=lifespan)

# Cross-origin (CORS) policy: accept requests from any origin, with any
# method and any header.
# NOTE(review): wildcard origins combined with allow_credentials=True likely
# lets any website make credentialed requests to this API; narrow
# allow_origins if cookies or auth headers are ever relied on — confirm.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
||
|
||
|
||
class DetectionRequest(BaseModel):
    # Request body of POST /detect_image.

    # Comma-separated detection type names; the handler trims and lower-cases
    # each entry and drops empty ones.
    type: str
    # Address of the image the handler downloads and runs detection on.
    url: HttpUrl
|
||
|
||
|
||
class DetectionProcessRequest(BaseModel):
    # Request body of POST /detect_process.

    # Image address, handed (as a str) to detect_large_image_from_url.
    url: HttpUrl
|
||
|
||
|
||
def _download_image_to_tempfile(url):
    """Download an image and stage it in a temporary file for the detector.

    Args:
        url: Address of the image to fetch.

    Returns:
        A ``(img_size, temp_path)`` tuple. ``img_size`` is ``[width, height]``
        read from the image header; ``temp_path`` is a ``.jpg``-suffixed file
        holding the downloaded bytes. The caller must delete ``temp_path``.

    Raises:
        Whatever ``requests`` raises for network errors or non-2xx responses
        (via ``raise_for_status``), whatever PIL raises for content it cannot
        identify as an image, and any OS error while writing the temp file.
    """
    response = requests.get(url, timeout=15)
    response.raise_for_status()
    content = response.content

    # Image.open is lazy: only the header is parsed, which is enough to get the
    # size and to reject payloads PIL does not recognise as an image.
    with Image.open(io.BytesIO(content)) as img:
        img_size = [img.width, img.height]

    # delete=False: the file has to outlive this handle so the detector can
    # open it by path. If the write fails, remove it here instead of leaking it.
    temp_file = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
    try:
        with temp_file:
            temp_file.write(content)
    except BaseException:
        os.remove(temp_file.name)
        raise
    return img_size, temp_file.name


@app.post("/detect_image", response_model=DetectionResponse)
async def run_detection_image(request: DetectionRequest):
    """Download the requested image and run the selected detectors on it.

    ``request.type`` is a comma-separated list of detection types (trimmed,
    lower-cased, empty entries dropped); ``request.url`` is the image to fetch.

    Returns:
        A dict with ``hasTarget`` (1 if the detector returned a truthy
        result, else 0), ``originalImgSize`` (``[width, height]``),
        ``targets`` (the detector's results, or ``[]``) and
        ``processing_errors`` (messages for detection failures, which are
        reported in the response rather than raised).

    Raises:
        HTTPException: 400 when no detection type is given, or when the image
            cannot be downloaded, decoded or written to a temporary file.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    requested_types = {t.strip().lower() for t in request.type.split(",") if t.strip()}
    print(f"请求的检测类型: {requested_types}")
    if not requested_types:
        raise HTTPException(status_code=400, detail="未指定检测类型")

    # requests.get blocks for up to the 15 s timeout; do the download and the
    # temp-file staging in a worker thread so the event loop stays responsive.
    try:
        img_size, temp_path = await run_in_threadpool(
            _download_image_to_tempfile, str(request.url)
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片处理失败:{e}") from e

    results = []
    errors = []
    try:
        # Sorted so the type string handed to the manager is deterministic
        # (set iteration order of str varies between processes).
        # NOTE(review): detect() still runs on the event loop, as before, which
        # keeps calls into the shared manager serialized. Offload it only after
        # confirming UnifiedDetectionManager is thread-safe.
        detection_results = detector_manager.detect(temp_path, ",".join(sorted(requested_types)))
        results = detection_results or []
    except Exception as e:
        errors.append(f"检测失败:{e}")
    finally:
        # Always remove the staged image, whether or not detection succeeded.
        if os.path.exists(temp_path):
            os.remove(temp_path)

    return {
        "hasTarget": 1 if results else 0,
        "originalImgSize": img_size,
        "targets": results,
        "processing_errors": errors,
    }
|
||
|
||
|
||
@app.post("/detect_process", response_model=DetectionResponse)
async def run_detection_process(request: DetectionProcessRequest):
    """Run ``detect_large_image_from_url`` on the image URL in the request body.

    The helper's return value is returned unchanged; FastAPI then validates
    and serializes it against ``DetectionResponse``.
    """
    image_url = str(request.url)
    # NOTE(review): detect_large_image_from_url is a synchronous call inside an
    # async handler, so it blocks the event loop while it runs. Consider
    # run_in_threadpool once the helper (and any model it uses) is confirmed
    # to be thread-safe.
    return detect_large_image_from_url(image_url)
|
||
|
||
|
||
@app.get("/supported_types")
async def get_supported_types():
    """List the detection types the manager reports as available.

    Returns ``{"supported_types": [...]}``, taken from the ``supported_types``
    entry of ``detector_manager.get_available_info()``. The list is empty when
    no manager is set (e.g. before lifespan() has created one).
    """
    if not detector_manager:
        return {"supported_types": []}

    available = detector_manager.get_available_info()
    return {"supported_types": available["supported_types"]}
|
||
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: serve the app with uvicorn.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    cli.add_argument("--reload", action="store_true")
    opts = cli.parse_args()

    # uvicorn needs an import string (not the app object) for --reload.
    # NOTE(review): "main:app" assumes this file is importable as module `main`.
    uvicorn.run(
        "main:app",
        host=opts.host,
        port=opts.port,
        reload=opts.reload,
    )
|