Files
VisDrone-Version/api_server.py
2025-08-05 16:55:45 +08:00

390 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import shutil
from typing import Dict, Optional, Any
from fastapi import APIRouter, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from result import Response
# 假设 rfdetr_core.py 在同一目录下或 PYTHONPATH 中
from rfdetr_core import RFDETRDetector
# --- Global Variables and Configuration ---
model_management_router = APIRouter()
BASE_MODEL_DIR = "models"
BASE_CONFIG_DIR = "configs"
os.makedirs(BASE_MODEL_DIR, exist_ok=True)
os.makedirs(BASE_CONFIG_DIR, exist_ok=True)
# 用于存储当前激活的检测器实例和可用模型信息
current_detector: Optional[RFDETRDetector] = None
current_model_identifier: Optional[str] = None
available_models_info: Dict[str, Dict] = {} # 存储模型标识符及其配置内容
def load_config(config_name: str) -> Optional[Dict]:
"""加载指定的JSON配置文件。"""
config_path = os.path.join(BASE_CONFIG_DIR, f"{config_name}.json")
if os.path.exists(config_path):
try:
with open(config_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"错误:加载配置文件 '{config_path}' 失败: {e}")
return None
return None
def initialize_detector(model_identifier: str) -> Optional[RFDETRDetector]:
"""根据模型标识符初始化检测器。"""
global current_detector, current_model_identifier
try:
config = available_models_info.get(model_identifier)
if not config:
print(f"错误:未找到模型 '{model_identifier}' 的缓存配置,尝试从磁盘加载。") # 更明确的日志
config_data = load_config(model_identifier)
if not config_data:
raise FileNotFoundError(f"配置文件 {model_identifier}.json 未找到或无法加载。")
available_models_info[model_identifier] = config_data
config = config_data
model_filename = config.get('model_pth_filename')
if not model_filename:
raise ValueError(f"配置文件 '{model_identifier}.json' 中缺少 'model_pth_filename' 字段。")
model_full_path = os.path.join(BASE_MODEL_DIR, model_filename)
if not os.path.exists(model_full_path):
raise FileNotFoundError(f"模型文件 '{model_full_path}' (在 '{model_identifier}.json' 中指定) 不存在。")
print(f"尝试使用配置 '{model_identifier}.json' 和模型 '{model_full_path}' 初始化检测器...")
detector = RFDETRDetector(config_name=model_identifier, base_model_dir=BASE_MODEL_DIR,
base_config_dir=BASE_CONFIG_DIR)
print(f"检测器 '{model_identifier}' 初始化成功。")
current_detector = detector
current_model_identifier = model_identifier
# 通知 DataPusher 更新其检测器实例
try:
from data_pusher import get_data_pusher_instance # 动态导入以避免潜在的循环依赖问题
pusher = get_data_pusher_instance()
if pusher:
print(f"通知 DataPusher 更新其检测器实例为: {model_identifier}")
pusher.update_detector_instance(current_detector)
# else:
# 如果 pusher 为 None可能是因为它尚未在主应用启动时被完全初始化
# data_pusher 模块内部的 initialize_data_pusher 负责记录其自身初始化状态
# print(f"DataPusher 实例尚未初始化 (或初始化失败),无法更新其检测器。")
except ImportError:
print(
"警告: 无法导入 data_pusher 模块以更新 DataPusher 的检测器实例。如果您不使用数据推送功能,此消息可忽略。")
except Exception as e_pusher:
print(f"警告: 通知 DataPusher 更新检测器时发生意外错误: {e_pusher}")
return detector
except FileNotFoundError as e:
print(f"初始化检测器 '{model_identifier}' 失败 (文件未找到): {e}")
if current_model_identifier == model_identifier:
current_detector = None
current_model_identifier = None
raise HTTPException(status_code=404, detail=str(e)) # 保持 404让上层知道是文件问题
except ValueError as e: # 捕获配置字段缺失等问题
print(f"初始化检测器 '{model_identifier}' 失败 (配置值错误): {e}")
if current_model_identifier == model_identifier:
current_detector = None
current_model_identifier = None
raise HTTPException(status_code=400, detail=str(e)) # 配置问题用 400
except Exception as e: # 其他来自 RFDETRDetector 内部的初始化错误
print(f"初始化检测器 '{model_identifier}' 失败 (内部错误): {e}")
# import traceback
# traceback.print_exc() # 在服务器日志中打印完整堆栈,方便调试
if current_model_identifier == model_identifier:
current_detector = None
current_model_identifier = None
if model_identifier in available_models_info:
del available_models_info[model_identifier]
# 这些通常是服务器端问题或模型/库的兼容性问题所以用500
raise HTTPException(status_code=500, detail=f"检测器 '{model_identifier}' 内部初始化失败: {e}")
def get_active_detector() -> Optional[RFDETRDetector]:
"""获取当前激活的检测器实例。"""
return current_detector
def get_active_model_identifier() -> Optional[str]:
"""获取当前激活的模型标识符。"""
return current_model_identifier
def scan_and_load_available_models():
"""扫描配置目录,加载所有有效的模型配置。"""
global available_models_info
available_models_info = {}
# print(f"扫描目录 '{BASE_CONFIG_DIR}' 以查找配置文件...") # 减少日志冗余
if not os.path.exists(BASE_CONFIG_DIR):
# print(f"配置目录 '{BASE_CONFIG_DIR}' 不存在,跳过扫描。")
return
for filename in os.listdir(BASE_CONFIG_DIR):
if filename.endswith(".json"):
model_identifier = filename[:-5]
# print(f"找到配置文件: {filename},模型标识符: {model_identifier}")
config_data = load_config(model_identifier)
if config_data:
model_pth = config_data.get('model_pth_filename')
model_full_path = os.path.join(BASE_MODEL_DIR, model_pth) if model_pth else None
if model_pth and os.path.exists(model_full_path):
available_models_info[model_identifier] = config_data
# print(f"模型配置 '{model_identifier}' 加载成功。")
# elif not model_pth:
# print(f"警告:模型 '{model_identifier}' 的配置文件中未指定 'model_pth_filename'。")
# else:
# print(f"警告:模型 '{model_identifier}' 的模型文件 '{model_full_path}' 未找到。")
# else:
# print(f"警告:无法加载模型 '{model_identifier}' 的配置文件。")
print(f"可用模型配置已扫描/更新: {list(available_models_info.keys())}")
# --- Extracted Startup Logic (to be called by the main app) ---
async def initialize_default_model_on_startup():
"""应用启动时,扫描并加载可用模型,尝试激活第一个。"""
print("执行模型管理模块的启动初始化...")
scan_and_load_available_models()
global current_model_identifier, current_detector
if available_models_info:
first_model = sorted(list(available_models_info.keys()))[0]
print(f"尝试将第一个可用模型 '{first_model}' 设置为活动模型。")
try:
initialize_detector(first_model)
print(f"默认模型 '{current_model_identifier}' 加载并激活成功。")
except HTTPException as e: # 捕获 initialize_detector 抛出的HTTPException
print(f"加载默认模型 '{first_model}' 失败: {e.detail} (状态码: {e.status_code})")
current_model_identifier = None
current_detector = None
# 不需要再捕获 Exception as e因为 initialize_detector 已经处理并转换为 HTTPException
else:
print("没有可用的模型配置,服务器启动但无默认模型激活。")
current_model_identifier = None
current_detector = None
# --- Pydantic Models for Request/Response ---
class ModelIdentifier(BaseModel):
model_identifier: str
# Define a standard response model for OpenAPI documentation
class StandardResponse(BaseModel):
code: int
data: Optional[Any] = None
message: str
# --- API Endpoints ---
@model_management_router.post("/upload_model_and_config/", response_model=StandardResponse)
async def upload_model_and_config(
model_identifier_form: str = Form(..., alias="model_identifier"),
config_file: UploadFile = File(...),
model_file: UploadFile = File(...)
):
"""
上传模型文件 (.pth) 和配置文件 (.json)。
- **model_identifier**: 模型的唯一名称 (例如 "人车检测")。配置文件将以此名称保存 (e.g., "人车检测.json")。
- **config_file**: JSON 配置文件。
- **model_file**: Pytorch 模型文件 (.pth)。其文件名必须与配置文件中 'model_pth_filename' 字段指定的一致。
"""
config_filename = f"{model_identifier_form}.json"
config_path = os.path.join(BASE_CONFIG_DIR, config_filename)
model_path = None # 在 try 块外部声明,以便 finally 和 except 中可用
global current_model_identifier, current_detector # current_detector 实际上在这里主要由 initialize_detector 设置
try:
print(f"开始上传模型 '{model_identifier_form}'...")
config_content = await config_file.read()
try:
config_data = json.loads(config_content.decode('utf-8'))
except json.JSONDecodeError:
# raise HTTPException(status_code=400, detail="无效的JSON配置文件格式请检查JSON语法。")
return JSONResponse(status_code=400,
content=Response.error(message="无效的JSON配置文件格式请检查JSON语法。", code=400))
except UnicodeDecodeError:
# raise HTTPException(status_code=400, detail="配置文件编码错误请确保为UTF-8编码。")
return JSONResponse(status_code=400,
content=Response.error(message="配置文件编码错误请确保为UTF-8编码。", code=400))
if 'model_pth_filename' not in config_data:
# raise HTTPException(status_code=400, detail="配置文件中必须包含 'model_pth_filename' 字段。")
return JSONResponse(status_code=400,
content=Response.error(message="配置文件中必须包含 'model_pth_filename' 字段。",
code=400))
target_model_filename = config_data['model_pth_filename']
if not isinstance(target_model_filename, str) or not target_model_filename.endswith(".pth"):
# raise HTTPException(status_code=400, detail="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。")
return JSONResponse(status_code=400, content=Response.error(
message="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。", code=400))
if model_file.filename != target_model_filename:
# raise HTTPException(
# status_code=400,
# detail=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。"
# )
return JSONResponse(
status_code=400,
content=Response.error(
message=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。",
code=400
)
)
with open(config_path, 'wb') as f:
f.write(config_content)
print(f"配置文件 '{config_path}' 已保存。")
model_path = os.path.join(BASE_MODEL_DIR, target_model_filename)
with open(model_path, "wb") as buffer:
shutil.copyfileobj(model_file.file, buffer)
print(f"模型文件 '{model_path}' 已保存。")
# 在尝试初始化之前,将配置数据添加到 available_models_info
# 这样 initialize_detector 即使在缓存未命中的情况下也能通过 load_config 找到它
available_models_info[model_identifier_form] = config_data
try:
print(f"尝试初始化并激活新上传的模型: '{model_identifier_form}'")
initialize_detector(model_identifier_form) # 此函数会设置全局的 current_detector 和 current_model_identifier
print(f"新上传的模型 '{model_identifier_form}' 验证并激活成功。")
# 成功后available_models_info 已包含此模型current_detector 和 current_model_identifier 已更新
# 无需在此处再次调用 scan_and_load_available_models()因为它会重新扫描所有可能覆盖内存中的一些状态或引入不必要的IO
except HTTPException as e: # 从 initialize_detector 抛出的错误
print(f"初始化新上传的模型 '{model_identifier_form}' 失败: {e.detail}")
# 清理已保存的文件和内存中的条目
if os.path.exists(config_path): os.remove(config_path)
if model_path and os.path.exists(model_path): os.remove(model_path)
if model_identifier_form in available_models_info: del available_models_info[model_identifier_form]
# scan_and_load_available_models() # 失败后重新扫描是好的,以确保 available_models_info 准确
# 但如果 initialize_detector 内部已经删除了 available_models_info 中的条目,这里可能不需要
# 为保持一致性,如果上面删除了,这里重新扫描一下比较稳妥
scan_and_load_available_models()
# 重新抛出为 422并包含原始错误信息
# raise HTTPException(status_code=422,
# detail=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}")
return JSONResponse(status_code=422,
content=Response.error(
message=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}",
code=422
))
# 如果成功,确保全局模型列表是最新的(虽然 initialize_detector 已更新了 current_*,但列表可能需要刷新以供 /available_models 使用)
# scan_and_load_available_models() # 移除这个,因为当前模型已激活,列表会在下次调用 /available_models 时刷新
# return UploadResponse(
# message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。",
# model_identifier=model_identifier_form,
# config_filename=config_filename,
# model_filename=target_model_filename
# )
return Response.success(data={
"message": f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。", # Message is also in the wrapper
"model_identifier": model_identifier_form,
"config_filename": config_filename,
"model_filename": target_model_filename
}, message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。")
except HTTPException as e: # 捕获直接由 FastAPI 验证或其他地方抛出的 HTTPException
# This might catch exceptions from initialize_detector if they are not caught internally by the above try-except for initialize_detector
return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code))
except Exception as e:
identifier_for_log = model_identifier_form if 'model_identifier_form' in locals() else "unknown"
print(f"上传模型 '{identifier_for_log}' 过程中发生意外的严重错误: {e}")
# import traceback
# traceback.print_exc()
# 尝试清理,以防文件部分写入
if config_path and os.path.exists(config_path):
os.remove(config_path)
if model_path and os.path.exists(model_path):
os.remove(model_path)
if 'model_identifier_form' in locals() and model_identifier_form in available_models_info:
del available_models_info[model_identifier_form]
scan_and_load_available_models() # 出错后务必刷新列表
# raise HTTPException(status_code=500, detail=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}")
return JSONResponse(status_code=500, content=Response.error(
message=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}", code=500))
finally:
if config_file: await config_file.close()
if model_file: await model_file.close()
@model_management_router.post("/select_model/", response_model=StandardResponse)
# @model_management_router.post("/select_model/{model_identifier_path}", response_model=StandardResponse)
async def select_model(model_identifier_path: str):
"""
根据提供的标识符选择并激活一个已上传的模型。
路径参数 `model_identifier_path` 即为模型的唯一名称。
"""
global current_model_identifier, current_detector
# 总是先扫描以获取最新的可用模型列表,以防外部文件更改
scan_and_load_available_models()
if model_identifier_path not in available_models_info:
# raise HTTPException(status_code=404, detail=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。")
return JSONResponse(status_code=404, content=Response.error(
message=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。", code=404))
if current_model_identifier == model_identifier_path and current_detector is not None:
print(f"模型 '{model_identifier_path}' 已经是活动模型。")
# return SelectModelResponse(
# message=f"模型 '{model_identifier_path}' 已是当前活动模型。",
# active_model=current_model_identifier
# )
return Response.success(data={
"active_model": current_model_identifier
}, message=f"模型 '{model_identifier_path}' 已是当前活动模型。")
try:
print(f"尝试激活模型: {model_identifier_path}")
initialize_detector(model_identifier_path)
print(f"模型 '{current_model_identifier}' 已成功激活。")
# return SelectModelResponse(
# message=f"模型 '{model_identifier_path}' 启动成功。",
# active_model=current_model_identifier
# )
return Response.success(data={
"active_model": current_model_identifier
}, message=f"模型 '{model_identifier_path}' 启动成功。")
except HTTPException as e:
# initialize_detector 内部已经处理了 current_model_identifier 和 current_detector 的清理
# 以及 available_models_info 中对应条目的移除(如果是其内部错误)
print(f"激活模型 '{model_identifier_path}' 失败: {e.detail} (原始状态码: {e.status_code})")
# 此处不再需要手动清理 available_models_info因为 initialize_detector 如果因内部错误(非文件找不到)
# 导致无法实例化 RFDETRDetector它会自己删除条目。
# 如果是文件找不到 (404 from initialize_detector),条目可能还在,但下次 scan 会处理。
scan_and_load_available_models() # 确保 select 失败后,列表也是最新的
# raise HTTPException(status_code=e.status_code, detail=e.detail) # 重新抛出原始的 HTTPException
return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code))
# 不需要再捕获 Exception as e因为 initialize_detector 已经处理并转换为 HTTPException
@model_management_router.get("/available_models/", response_model=StandardResponse)
async def get_available_models():
"""列出所有当前可用的模型标识符。"""
scan_and_load_available_models()
# return list(available_models_info.keys())
return Response.success(data=list(available_models_info.keys()), message="成功获取可用模型列表。")
@model_management_router.get("/current_model/", response_model=StandardResponse)
async def get_current_model_endpoint():
"""获取当前激活的模型标识符。"""
# return current_model_identifier
if current_model_identifier:
return Response.success(data=current_model_identifier, message="成功获取当前激活的模型。")
else:
return Response.success(data=None, message="当前没有激活的模型。")