参数分离完全版
This commit is contained in:
389
api_server.py
Normal file
389
api_server.py
Normal file
@ -0,0 +1,389 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile, HTTPException, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from result import Response
|
||||
# 假设 rfdetr_core.py 在同一目录下或 PYTHONPATH 中
|
||||
from rfdetr_core import RFDETRDetector
|
||||
|
||||
# --- Global Variables and Configuration ---
|
||||
model_management_router = APIRouter()
|
||||
|
||||
BASE_MODEL_DIR = "models"
|
||||
BASE_CONFIG_DIR = "configs"
|
||||
os.makedirs(BASE_MODEL_DIR, exist_ok=True)
|
||||
os.makedirs(BASE_CONFIG_DIR, exist_ok=True)
|
||||
|
||||
# 用于存储当前激活的检测器实例和可用模型信息
|
||||
current_detector: Optional[RFDETRDetector] = None
|
||||
current_model_identifier: Optional[str] = None
|
||||
available_models_info: Dict[str, Dict] = {} # 存储模型标识符及其配置内容
|
||||
|
||||
|
||||
def load_config(config_name: str) -> Optional[Dict]:
|
||||
"""加载指定的JSON配置文件。"""
|
||||
config_path = os.path.join(BASE_CONFIG_DIR, f"{config_name}.json")
|
||||
if os.path.exists(config_path):
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(f"错误:加载配置文件 '{config_path}' 失败: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def initialize_detector(model_identifier: str) -> Optional[RFDETRDetector]:
|
||||
"""根据模型标识符初始化检测器。"""
|
||||
global current_detector, current_model_identifier
|
||||
try:
|
||||
config = available_models_info.get(model_identifier)
|
||||
if not config:
|
||||
print(f"错误:未找到模型 '{model_identifier}' 的缓存配置,尝试从磁盘加载。") # 更明确的日志
|
||||
config_data = load_config(model_identifier)
|
||||
if not config_data:
|
||||
raise FileNotFoundError(f"配置文件 {model_identifier}.json 未找到或无法加载。")
|
||||
available_models_info[model_identifier] = config_data
|
||||
config = config_data
|
||||
|
||||
model_filename = config.get('model_pth_filename')
|
||||
if not model_filename:
|
||||
raise ValueError(f"配置文件 '{model_identifier}.json' 中缺少 'model_pth_filename' 字段。")
|
||||
|
||||
model_full_path = os.path.join(BASE_MODEL_DIR, model_filename)
|
||||
if not os.path.exists(model_full_path):
|
||||
raise FileNotFoundError(f"模型文件 '{model_full_path}' (在 '{model_identifier}.json' 中指定) 不存在。")
|
||||
|
||||
print(f"尝试使用配置 '{model_identifier}.json' 和模型 '{model_full_path}' 初始化检测器...")
|
||||
detector = RFDETRDetector(config_name=model_identifier, base_model_dir=BASE_MODEL_DIR,
|
||||
base_config_dir=BASE_CONFIG_DIR)
|
||||
print(f"检测器 '{model_identifier}' 初始化成功。")
|
||||
current_detector = detector
|
||||
current_model_identifier = model_identifier
|
||||
|
||||
# 通知 DataPusher 更新其检测器实例
|
||||
try:
|
||||
from data_pusher import get_data_pusher_instance # 动态导入以避免潜在的循环依赖问题
|
||||
pusher = get_data_pusher_instance()
|
||||
if pusher:
|
||||
print(f"通知 DataPusher 更新其检测器实例为: {model_identifier}")
|
||||
pusher.update_detector_instance(current_detector)
|
||||
# else:
|
||||
# 如果 pusher 为 None,可能是因为它尚未在主应用启动时被完全初始化
|
||||
# data_pusher 模块内部的 initialize_data_pusher 负责记录其自身初始化状态
|
||||
# print(f"DataPusher 实例尚未初始化 (或初始化失败),无法更新其检测器。")
|
||||
except ImportError:
|
||||
print(
|
||||
"警告: 无法导入 data_pusher 模块以更新 DataPusher 的检测器实例。如果您不使用数据推送功能,此消息可忽略。")
|
||||
except Exception as e_pusher:
|
||||
print(f"警告: 通知 DataPusher 更新检测器时发生意外错误: {e_pusher}")
|
||||
|
||||
return detector
|
||||
except FileNotFoundError as e:
|
||||
print(f"初始化检测器 '{model_identifier}' 失败 (文件未找到): {e}")
|
||||
if current_model_identifier == model_identifier:
|
||||
current_detector = None
|
||||
current_model_identifier = None
|
||||
raise HTTPException(status_code=404, detail=str(e)) # 保持 404,让上层知道是文件问题
|
||||
except ValueError as e: # 捕获配置字段缺失等问题
|
||||
print(f"初始化检测器 '{model_identifier}' 失败 (配置值错误): {e}")
|
||||
if current_model_identifier == model_identifier:
|
||||
current_detector = None
|
||||
current_model_identifier = None
|
||||
raise HTTPException(status_code=400, detail=str(e)) # 配置问题用 400
|
||||
except Exception as e: # 其他来自 RFDETRDetector 内部的初始化错误
|
||||
print(f"初始化检测器 '{model_identifier}' 失败 (内部错误): {e}")
|
||||
# import traceback
|
||||
# traceback.print_exc() # 在服务器日志中打印完整堆栈,方便调试
|
||||
if current_model_identifier == model_identifier:
|
||||
current_detector = None
|
||||
current_model_identifier = None
|
||||
if model_identifier in available_models_info:
|
||||
del available_models_info[model_identifier]
|
||||
# 这些通常是服务器端问题或模型/库的兼容性问题,所以用500
|
||||
raise HTTPException(status_code=500, detail=f"检测器 '{model_identifier}' 内部初始化失败: {e}")
|
||||
|
||||
|
||||
def get_active_detector() -> Optional[RFDETRDetector]:
|
||||
"""获取当前激活的检测器实例。"""
|
||||
return current_detector
|
||||
|
||||
|
||||
def get_active_model_identifier() -> Optional[str]:
|
||||
"""获取当前激活的模型标识符。"""
|
||||
return current_model_identifier
|
||||
|
||||
|
||||
def scan_and_load_available_models():
|
||||
"""扫描配置目录,加载所有有效的模型配置。"""
|
||||
global available_models_info
|
||||
available_models_info = {}
|
||||
# print(f"扫描目录 '{BASE_CONFIG_DIR}' 以查找配置文件...") # 减少日志冗余
|
||||
if not os.path.exists(BASE_CONFIG_DIR):
|
||||
# print(f"配置目录 '{BASE_CONFIG_DIR}' 不存在,跳过扫描。")
|
||||
return
|
||||
for filename in os.listdir(BASE_CONFIG_DIR):
|
||||
if filename.endswith(".json"):
|
||||
model_identifier = filename[:-5]
|
||||
# print(f"找到配置文件: {filename},模型标识符: {model_identifier}")
|
||||
config_data = load_config(model_identifier)
|
||||
if config_data:
|
||||
model_pth = config_data.get('model_pth_filename')
|
||||
model_full_path = os.path.join(BASE_MODEL_DIR, model_pth) if model_pth else None
|
||||
if model_pth and os.path.exists(model_full_path):
|
||||
available_models_info[model_identifier] = config_data
|
||||
# print(f"模型配置 '{model_identifier}' 加载成功。")
|
||||
# elif not model_pth:
|
||||
# print(f"警告:模型 '{model_identifier}' 的配置文件中未指定 'model_pth_filename'。")
|
||||
# else:
|
||||
# print(f"警告:模型 '{model_identifier}' 的模型文件 '{model_full_path}' 未找到。")
|
||||
# else:
|
||||
# print(f"警告:无法加载模型 '{model_identifier}' 的配置文件。")
|
||||
print(f"可用模型配置已扫描/更新: {list(available_models_info.keys())}")
|
||||
|
||||
|
||||
# --- Extracted Startup Logic (to be called by the main app) ---
|
||||
async def initialize_default_model_on_startup():
|
||||
"""应用启动时,扫描并加载可用模型,尝试激活第一个。"""
|
||||
print("执行模型管理模块的启动初始化...")
|
||||
scan_and_load_available_models()
|
||||
global current_model_identifier, current_detector
|
||||
|
||||
if available_models_info:
|
||||
first_model = sorted(list(available_models_info.keys()))[0]
|
||||
print(f"尝试将第一个可用模型 '{first_model}' 设置为活动模型。")
|
||||
try:
|
||||
initialize_detector(first_model)
|
||||
print(f"默认模型 '{current_model_identifier}' 加载并激活成功。")
|
||||
except HTTPException as e: # 捕获 initialize_detector 抛出的HTTPException
|
||||
print(f"加载默认模型 '{first_model}' 失败: {e.detail} (状态码: {e.status_code})")
|
||||
current_model_identifier = None
|
||||
current_detector = None
|
||||
# 不需要再捕获 Exception as e,因为 initialize_detector 已经处理并转换为 HTTPException
|
||||
else:
|
||||
print("没有可用的模型配置,服务器启动但无默认模型激活。")
|
||||
current_model_identifier = None
|
||||
current_detector = None
|
||||
|
||||
|
||||
# --- Pydantic Models for Request/Response ---
|
||||
|
||||
class ModelIdentifier(BaseModel):
|
||||
model_identifier: str
|
||||
|
||||
|
||||
# Define a standard response model for OpenAPI documentation
|
||||
class StandardResponse(BaseModel):
|
||||
code: int
|
||||
data: Optional[Any] = None
|
||||
message: str
|
||||
|
||||
|
||||
# --- API Endpoints ---
|
||||
|
||||
@model_management_router.post("/upload_model_and_config/", response_model=StandardResponse)
|
||||
async def upload_model_and_config(
|
||||
model_identifier_form: str = Form(..., alias="model_identifier"),
|
||||
config_file: UploadFile = File(...),
|
||||
model_file: UploadFile = File(...)
|
||||
):
|
||||
"""
|
||||
上传模型文件 (.pth) 和配置文件 (.json)。
|
||||
- **model_identifier**: 模型的唯一名称 (例如 "人车检测")。配置文件将以此名称保存 (e.g., "人车检测.json")。
|
||||
- **config_file**: JSON 配置文件。
|
||||
- **model_file**: Pytorch 模型文件 (.pth)。其文件名必须与配置文件中 'model_pth_filename' 字段指定的一致。
|
||||
"""
|
||||
config_filename = f"{model_identifier_form}.json"
|
||||
config_path = os.path.join(BASE_CONFIG_DIR, config_filename)
|
||||
model_path = None # 在 try 块外部声明,以便 finally 和 except 中可用
|
||||
global current_model_identifier, current_detector # current_detector 实际上在这里主要由 initialize_detector 设置
|
||||
|
||||
try:
|
||||
print(f"开始上传模型 '{model_identifier_form}'...")
|
||||
config_content = await config_file.read()
|
||||
try:
|
||||
config_data = json.loads(config_content.decode('utf-8'))
|
||||
except json.JSONDecodeError:
|
||||
# raise HTTPException(status_code=400, detail="无效的JSON配置文件格式,请检查JSON语法。")
|
||||
return JSONResponse(status_code=400,
|
||||
content=Response.error(message="无效的JSON配置文件格式,请检查JSON语法。", code=400))
|
||||
except UnicodeDecodeError:
|
||||
# raise HTTPException(status_code=400, detail="配置文件编码错误,请确保为UTF-8编码。")
|
||||
return JSONResponse(status_code=400,
|
||||
content=Response.error(message="配置文件编码错误,请确保为UTF-8编码。", code=400))
|
||||
|
||||
if 'model_pth_filename' not in config_data:
|
||||
# raise HTTPException(status_code=400, detail="配置文件中必须包含 'model_pth_filename' 字段。")
|
||||
return JSONResponse(status_code=400,
|
||||
content=Response.error(message="配置文件中必须包含 'model_pth_filename' 字段。",
|
||||
code=400))
|
||||
|
||||
target_model_filename = config_data['model_pth_filename']
|
||||
if not isinstance(target_model_filename, str) or not target_model_filename.endswith(".pth"):
|
||||
# raise HTTPException(status_code=400, detail="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。")
|
||||
return JSONResponse(status_code=400, content=Response.error(
|
||||
message="配置文件中的 'model_pth_filename' 必须是有效的 .pth 文件名字符串。", code=400))
|
||||
|
||||
if model_file.filename != target_model_filename:
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。"
|
||||
# )
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=Response.error(
|
||||
message=f"上传的模型文件名 '{model_file.filename}' 与配置文件中指定的 '{target_model_filename}' 不匹配。",
|
||||
code=400
|
||||
)
|
||||
)
|
||||
|
||||
with open(config_path, 'wb') as f:
|
||||
f.write(config_content)
|
||||
print(f"配置文件 '{config_path}' 已保存。")
|
||||
|
||||
model_path = os.path.join(BASE_MODEL_DIR, target_model_filename)
|
||||
with open(model_path, "wb") as buffer:
|
||||
shutil.copyfileobj(model_file.file, buffer)
|
||||
print(f"模型文件 '{model_path}' 已保存。")
|
||||
|
||||
# 在尝试初始化之前,将配置数据添加到 available_models_info
|
||||
# 这样 initialize_detector 即使在缓存未命中的情况下也能通过 load_config 找到它
|
||||
available_models_info[model_identifier_form] = config_data
|
||||
|
||||
try:
|
||||
print(f"尝试初始化并激活新上传的模型: '{model_identifier_form}'")
|
||||
initialize_detector(model_identifier_form) # 此函数会设置全局的 current_detector 和 current_model_identifier
|
||||
print(f"新上传的模型 '{model_identifier_form}' 验证并激活成功。")
|
||||
# 成功后,available_models_info 已包含此模型,current_detector 和 current_model_identifier 已更新
|
||||
# 无需在此处再次调用 scan_and_load_available_models(),因为它会重新扫描所有,可能覆盖内存中的一些状态或引入不必要的IO
|
||||
except HTTPException as e: # 从 initialize_detector 抛出的错误
|
||||
print(f"初始化新上传的模型 '{model_identifier_form}' 失败: {e.detail}")
|
||||
# 清理已保存的文件和内存中的条目
|
||||
if os.path.exists(config_path): os.remove(config_path)
|
||||
if model_path and os.path.exists(model_path): os.remove(model_path)
|
||||
if model_identifier_form in available_models_info: del available_models_info[model_identifier_form]
|
||||
# scan_and_load_available_models() # 失败后重新扫描是好的,以确保 available_models_info 准确
|
||||
# 但如果 initialize_detector 内部已经删除了 available_models_info 中的条目,这里可能不需要
|
||||
# 为保持一致性,如果上面删除了,这里重新扫描一下比较稳妥
|
||||
scan_and_load_available_models()
|
||||
|
||||
# 重新抛出为 422,并包含原始错误信息
|
||||
# raise HTTPException(status_code=422,
|
||||
# detail=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}")
|
||||
return JSONResponse(status_code=422,
|
||||
content=Response.error(
|
||||
message=f"模型 '{model_identifier_form}' 已上传但初始化失败,请检查模型或配置内容。原始错误: {e.detail}",
|
||||
code=422
|
||||
))
|
||||
|
||||
# 如果成功,确保全局模型列表是最新的(虽然 initialize_detector 已更新了 current_*,但列表可能需要刷新以供 /available_models 使用)
|
||||
# scan_and_load_available_models() # 移除这个,因为当前模型已激活,列表会在下次调用 /available_models 时刷新
|
||||
|
||||
# return UploadResponse(
|
||||
# message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。",
|
||||
# model_identifier=model_identifier_form,
|
||||
# config_filename=config_filename,
|
||||
# model_filename=target_model_filename
|
||||
# )
|
||||
return Response.success(data={
|
||||
"message": f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。", # Message is also in the wrapper
|
||||
"model_identifier": model_identifier_form,
|
||||
"config_filename": config_filename,
|
||||
"model_filename": target_model_filename
|
||||
}, message=f"模型 '{model_identifier_form}' 和配置文件上传成功并已激活。")
|
||||
|
||||
except HTTPException as e: # 捕获直接由 FastAPI 验证或其他地方抛出的 HTTPException
|
||||
# This might catch exceptions from initialize_detector if they are not caught internally by the above try-except for initialize_detector
|
||||
return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code))
|
||||
except Exception as e:
|
||||
identifier_for_log = model_identifier_form if 'model_identifier_form' in locals() else "unknown"
|
||||
print(f"上传模型 '{identifier_for_log}' 过程中发生意外的严重错误: {e}")
|
||||
# import traceback
|
||||
# traceback.print_exc()
|
||||
# 尝试清理,以防文件部分写入
|
||||
if config_path and os.path.exists(config_path):
|
||||
os.remove(config_path)
|
||||
if model_path and os.path.exists(model_path):
|
||||
os.remove(model_path)
|
||||
if 'model_identifier_form' in locals() and model_identifier_form in available_models_info:
|
||||
del available_models_info[model_identifier_form]
|
||||
scan_and_load_available_models() # 出错后务必刷新列表
|
||||
|
||||
# raise HTTPException(status_code=500, detail=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}")
|
||||
return JSONResponse(status_code=500, content=Response.error(
|
||||
message=f"上传模型 '{identifier_for_log}' 过程中发生内部服务器错误: {str(e)}", code=500))
|
||||
finally:
|
||||
if config_file: await config_file.close()
|
||||
if model_file: await model_file.close()
|
||||
|
||||
|
||||
@model_management_router.post("/select_model/", response_model=StandardResponse)
|
||||
# @model_management_router.post("/select_model/{model_identifier_path}", response_model=StandardResponse)
|
||||
async def select_model(model_identifier_path: str):
|
||||
"""
|
||||
根据提供的标识符选择并激活一个已上传的模型。
|
||||
路径参数 `model_identifier_path` 即为模型的唯一名称。
|
||||
"""
|
||||
global current_model_identifier, current_detector
|
||||
|
||||
# 总是先扫描以获取最新的可用模型列表,以防外部文件更改
|
||||
scan_and_load_available_models()
|
||||
if model_identifier_path not in available_models_info:
|
||||
# raise HTTPException(status_code=404, detail=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。")
|
||||
return JSONResponse(status_code=404, content=Response.error(
|
||||
message=f"模型标识符 '{model_identifier_path}' 未在可用配置中找到。请确保已上传或名称正确。", code=404))
|
||||
|
||||
if current_model_identifier == model_identifier_path and current_detector is not None:
|
||||
print(f"模型 '{model_identifier_path}' 已经是活动模型。")
|
||||
# return SelectModelResponse(
|
||||
# message=f"模型 '{model_identifier_path}' 已是当前活动模型。",
|
||||
# active_model=current_model_identifier
|
||||
# )
|
||||
return Response.success(data={
|
||||
"active_model": current_model_identifier
|
||||
}, message=f"模型 '{model_identifier_path}' 已是当前活动模型。")
|
||||
try:
|
||||
print(f"尝试激活模型: {model_identifier_path}")
|
||||
initialize_detector(model_identifier_path)
|
||||
print(f"模型 '{current_model_identifier}' 已成功激活。")
|
||||
# return SelectModelResponse(
|
||||
# message=f"模型 '{model_identifier_path}' 启动成功。",
|
||||
# active_model=current_model_identifier
|
||||
# )
|
||||
return Response.success(data={
|
||||
"active_model": current_model_identifier
|
||||
}, message=f"模型 '{model_identifier_path}' 启动成功。")
|
||||
except HTTPException as e:
|
||||
# initialize_detector 内部已经处理了 current_model_identifier 和 current_detector 的清理
|
||||
# 以及 available_models_info 中对应条目的移除(如果是其内部错误)
|
||||
print(f"激活模型 '{model_identifier_path}' 失败: {e.detail} (原始状态码: {e.status_code})")
|
||||
# 此处不再需要手动清理 available_models_info,因为 initialize_detector 如果因内部错误(非文件找不到)
|
||||
# 导致无法实例化 RFDETRDetector,它会自己删除条目。
|
||||
# 如果是文件找不到 (404 from initialize_detector),条目可能还在,但下次 scan 会处理。
|
||||
scan_and_load_available_models() # 确保 select 失败后,列表也是最新的
|
||||
# raise HTTPException(status_code=e.status_code, detail=e.detail) # 重新抛出原始的 HTTPException
|
||||
return JSONResponse(status_code=e.status_code, content=Response.error(message=e.detail, code=e.status_code))
|
||||
# 不需要再捕获 Exception as e,因为 initialize_detector 已经处理并转换为 HTTPException
|
||||
|
||||
|
||||
@model_management_router.get("/available_models/", response_model=StandardResponse)
|
||||
async def get_available_models():
|
||||
"""列出所有当前可用的模型标识符。"""
|
||||
scan_and_load_available_models()
|
||||
# return list(available_models_info.keys())
|
||||
return Response.success(data=list(available_models_info.keys()), message="成功获取可用模型列表。")
|
||||
|
||||
|
||||
@model_management_router.get("/current_model/", response_model=StandardResponse)
|
||||
async def get_current_model_endpoint():
|
||||
"""获取当前激活的模型标识符。"""
|
||||
# return current_model_identifier
|
||||
if current_model_identifier:
|
||||
return Response.success(data=current_model_identifier, message="成功获取当前激活的模型。")
|
||||
else:
|
||||
return Response.success(data=None, message="当前没有激活的模型。")
|
Reference in New Issue
Block a user